feat(admin): 添加临时不可调度功能

当账号触发特定错误码和关键词匹配时,自动临时禁用调度:

后端:
- 新增 TempUnschedCache Redis 缓存层
- RateLimitService 支持规则匹配和状态管理
- 添加 GET/DELETE /accounts/:id/temp-unschedulable API
- 数据库迁移添加 temp_unschedulable_until/reason 字段

前端:
- 账号状态指示器显示临时不可调度状态
- 新增 TempUnschedStatusModal 详情弹窗
- 创建/编辑账号时支持配置规则和预设模板
- 完整的中英文国际化支持
This commit is contained in:
ianshaw
2026-01-03 06:34:00 -08:00
parent acb718d355
commit 09da6904f5
18 changed files with 1829 additions and 199 deletions

View File

@@ -1,9 +1,11 @@
package admin
import (
"errors"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -69,42 +71,45 @@ func NewAccountHandler(
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
Name string `json:"name"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// AccountWithConcurrency extends Account with real-time concurrency info
@@ -179,18 +184,40 @@ func (h *AccountHandler) Create(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
@@ -213,18 +240,40 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
response.ErrorFrom(c, err)
return
}
@@ -568,6 +617,9 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
return
}
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
hasUpdates := req.Name != "" ||
req.ProxyID != nil ||
req.Concurrency != nil ||
@@ -583,15 +635,16 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
}
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
AccountIDs: req.AccountIDs,
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -781,6 +834,49 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
}
// GetTempUnschedulable handles getting temporary unschedulable status
// GET /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) GetTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
state, err := h.rateLimitService.GetTempUnschedStatus(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if state == nil || state.UntilUnix <= time.Now().Unix() {
response.Success(c, gin.H{"active": false})
return
}
response.Success(c, gin.H{
"active": true,
"state": state,
})
}
// ClearTempUnschedulable handles clearing temporary unschedulable status
// DELETE /api/v1/admin/accounts/:id/temp-unschedulable
func (h *AccountHandler) ClearTempUnschedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
if err := h.rateLimitService.ClearTempUnschedulable(c.Request.Context(), accountID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Temp unschedulable cleared successfully"})
}
// GetTodayStats handles getting account today statistics
// GET /api/v1/admin/accounts/:id/today-stats
func (h *AccountHandler) GetTodayStats(c *gin.Context) {

View File

@@ -1,4 +1,3 @@
// Package dto provides mapping utilities for converting between service layer and HTTP handler DTOs.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
@@ -27,11 +26,11 @@ func UserFromService(u *service.User) *User {
return nil
}
out := UserFromServiceShallow(u)
if len(u.APIKeys) > 0 {
out.APIKeys = make([]APIKey, 0, len(u.APIKeys))
for i := range u.APIKeys {
k := u.APIKeys[i]
out.APIKeys = append(out.APIKeys, *APIKeyFromService(&k))
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
@@ -44,11 +43,11 @@ func UserFromService(u *service.User) *User {
return out
}
func APIKeyFromService(k *service.APIKey) *APIKey {
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &APIKey{
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
@@ -104,28 +103,30 @@ func AccountFromServiceShallow(a *service.Account) *Account {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
TempUnschedulableUntil: a.TempUnschedulableUntil,
TempUnschedulableReason: a.TempUnschedulableReason,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
@@ -221,7 +222,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
@@ -246,7 +247,7 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),

View File

@@ -15,11 +15,11 @@ type User struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
APIKeys []APIKey `json:"api_keys,omitempty"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type APIKey struct {
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
@@ -76,6 +76,9 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"`
TempUnschedulableReason string `json:"temp_unschedulable_reason"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
@@ -136,7 +139,7 @@ type RedeemCode struct {
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -168,7 +171,7 @@ type UsageLog struct {
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
APIKey *APIKey `json:"api_key,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`

View File

@@ -43,6 +43,11 @@ type accountRepository struct {
sql sqlExecutor // 原生 SQL 执行接口
}
type tempUnschedSnapshot struct {
until *time.Time
reason string
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB) service.AccountRepository {
@@ -165,6 +170,11 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -191,6 +201,10 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
@@ -550,6 +564,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
Where(
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -575,6 +590,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount.PlatformEQ(platform),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -607,6 +623,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount.PlatformIn(platforms...),
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -648,6 +665,31 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
return err
}
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = $1,
temp_unschedulable_reason = $2,
updated_at = NOW()
WHERE id = $3
AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`, until, reason, id)
return err
}
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
_, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET temp_unschedulable_until = NULL,
temp_unschedulable_reason = NULL,
updated_at = NOW()
WHERE id = $1
AND deleted_at IS NULL
`, id)
return err
}
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -808,6 +850,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
now := time.Now()
preds = append(preds,
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
@@ -869,6 +912,10 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -894,12 +941,68 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
return outAccounts, nil
}
func tempUnschedulablePredicate() dbpredicate.Account {
return dbpredicate.Account(func(s *entsql.Selector) {
col := s.C("temp_unschedulable_until")
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.LTE(col, entsql.Expr("NOW()")),
))
})
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
return out, nil
}
rows, err := r.sql.QueryContext(ctx, `
SELECT id, temp_unschedulable_until, temp_unschedulable_reason
FROM accounts
WHERE id = ANY($1)
`, pq.Array(accountIDs))
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var id int64
var until sql.NullTime
var reason sql.NullString
if err := rows.Scan(&id, &until, &reason); err != nil {
return nil, err
}
var untilPtr *time.Time
if until.Valid {
tmp := until.Time
untilPtr = &tmp
}
if reason.Valid {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
} else {
out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {

View File

@@ -0,0 +1,91 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const tempUnschedPrefix = "temp_unsched:account:"
var tempUnschedSetScript = redis.NewScript(`
local key = KEYS[1]
local new_until = tonumber(ARGV[1])
local new_value = ARGV[2]
local new_ttl = tonumber(ARGV[3])
local existing = redis.call('GET', key)
if existing then
local ok, existing_data = pcall(cjson.decode, existing)
if ok and existing_data and existing_data.until_unix then
local existing_until = tonumber(existing_data.until_unix)
if existing_until and new_until <= existing_until then
return 0
end
end
end
redis.call('SET', key, new_value, 'EX', new_ttl)
return 1
`)
type tempUnschedCache struct {
rdb *redis.Client
}
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
return &tempUnschedCache{rdb: rdb}
}
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
stateJSON, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state: %w", err)
}
ttl := time.Until(time.Unix(state.UntilUnix, 0))
if ttl <= 0 {
return nil // 已过期,不设置
}
ttlSeconds := int(ttl.Seconds())
if ttlSeconds < 1 {
ttlSeconds = 1
}
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
return err
}
// GetTempUnsched 获取临时不可调度状态
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
var state service.TempUnschedState
if err := json.Unmarshal([]byte(val), &state); err != nil {
return nil, fmt.Errorf("unmarshal state: %w", err)
}
return &state, nil
}
// DeleteTempUnsched 删除临时不可调度状态
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -19,9 +19,6 @@ func RegisterAdminRoutes(
// 仪表盘
registerDashboardRoutes(admin, h)
// 运维监控
registerOpsRoutes(admin, h)
// 用户管理
registerUserManagementRoutes(admin, h)
@@ -70,35 +67,10 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
}
}
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops := admin.Group("/ops")
{
ops.GET("/metrics", h.Admin.Ops.GetMetrics)
ops.GET("/metrics/history", h.Admin.Ops.ListMetricsHistory)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/error-logs", h.Admin.Ops.ListErrorLogs)
// Dashboard routes
dashboard := ops.Group("/dashboard")
{
dashboard.GET("/overview", h.Admin.Ops.GetDashboardOverview)
dashboard.GET("/providers", h.Admin.Ops.GetProviderHealth)
dashboard.GET("/latency-histogram", h.Admin.Ops.GetLatencyHistogram)
dashboard.GET("/errors/distribution", h.Admin.Ops.GetErrorDistribution)
}
// WebSocket routes
ws := ops.Group("/ws")
{
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
}
@@ -151,6 +123,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
@@ -231,12 +205,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSMTPConnection)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
}
@@ -276,7 +250,7 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -19,7 +20,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
@@ -30,7 +31,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +66,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// CreateUserInput represents the input for creating a new user
// Input types for admin operations
type CreateUserInput struct {
Email string
Password string
@@ -122,18 +123,22 @@ type CreateAccountInput struct {
Concurrency int
Priority int
GroupIDs []int64
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult captures the result for a single account update.
@@ -220,7 +228,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
apiKeyRepo ApiKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
@@ -232,7 +240,7 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
apiKeyRepo ApiKeyRepository,
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
break
}
}
}
}
// 检查混合渠道风险(除非用户已确认)
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
return nil, err
}
}
account := &Account{
Name: input.Name,
Platform: input.Platform,
@@ -636,22 +667,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
break
}
}
}
}
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
return nil, err
@@ -702,6 +717,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return nil, fmt.Errorf("get group: %w", err)
}
}
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
return nil, err
}
}
}
if err := s.accountRepo.Update(ctx, account); err != nil {
@@ -730,6 +752,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
// Preload account platforms for mixed channel risk checks if group bindings are requested.
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -761,6 +797,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1005,3 +1064,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
Country: exitInfo.Country,
}, nil
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道Antigravity + Anthropic
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
// 不是 Antigravity 或 Anthropic无需检查
return nil
}
// 检查每个分组中的其他账号
for _, groupID := range groupIDs {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
// 检查是否存在不同渠道的账号
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue // 跳过当前账号
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue // 不是 Antigravity 或 Anthropic跳过
}
// 检测混合渠道
if currentPlatform != otherPlatform {
group, _ := s.groupRepo.GetByID(ctx, groupID)
groupName := fmt.Sprintf("Group %d", groupID)
if group != nil {
groupName = group.Name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
case PlatformAntigravity:
return "Antigravity"
case PlatformAnthropic, "claude":
return "Anthropic"
default:
return ""
}
}
// MixedChannelError 混合渠道错误
type MixedChannelError struct {
GroupID int64
GroupName string
CurrentPlatform string
OtherPlatform string
}
func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
tempUnschedCache: tempUnschedCache,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
@@ -51,32 +54,39 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
shouldDisable = true
case 402:
// 支付要求:余额不足或计费问题,停止调度
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
return true
shouldDisable = true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
shouldDisable = true
case 429:
s.handle429(ctx, account, headers)
return false
shouldDisable = false
case 529:
s.handle529(ctx, account)
return false
shouldDisable = false
default:
// 其他5xx错误记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
shouldDisable = false
}
if tempMatched {
return true
}
return shouldDisable
}
// PreCheckUsage proactively checks local quota before dispatching a request.
@@ -287,3 +297,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.accountRepo.ClearRateLimit(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
return err
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
}
}
return nil
}
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
now := time.Now().Unix()
if s.tempUnschedCache != nil {
state, err := s.tempUnschedCache.GetTempUnsched(ctx, accountID)
if err != nil {
return nil, err
}
if state != nil && state.UntilUnix > now {
return state, nil
}
}
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, err
}
if account.TempUnschedulableUntil == nil {
return nil, nil
}
if account.TempUnschedulableUntil.Unix() <= now {
return nil, nil
}
state := &TempUnschedState{
UntilUnix: account.TempUnschedulableUntil.Unix(),
}
if account.TempUnschedulableReason != "" {
var parsed TempUnschedState
if err := json.Unmarshal([]byte(account.TempUnschedulableReason), &parsed); err == nil {
if parsed.UntilUnix == 0 {
parsed.UntilUnix = state.UntilUnix
}
state = &parsed
} else {
state.ErrorMessage = account.TempUnschedulableReason
}
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
}
}
return state, nil
}
func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil {
return false
}
if !account.ShouldHandleErrorCode(statusCode) {
return false
}
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
const tempUnschedBodyMaxBytes = 64 << 10
const tempUnschedMessageMaxBytes = 2048
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil {
return false
}
if !account.IsTempUnschedulableEnabled() {
return false
}
rules := account.GetTempUnschedulableRules()
if len(rules) == 0 {
return false
}
if statusCode <= 0 || len(responseBody) == 0 {
return false
}
body := responseBody
if len(body) > tempUnschedBodyMaxBytes {
body = body[:tempUnschedBodyMaxBytes]
}
bodyLower := strings.ToLower(string(body))
for idx, rule := range rules {
if rule.ErrorCode != statusCode || len(rule.Keywords) == 0 {
continue
}
matchedKeyword := matchTempUnschedKeyword(bodyLower, rule.Keywords)
if matchedKeyword == "" {
continue
}
if s.triggerTempUnschedulable(ctx, account, rule, idx, statusCode, matchedKeyword, responseBody) {
return true
}
}
return false
}
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
if bodyLower == "" {
return ""
}
for _, keyword := range keywords {
k := strings.TrimSpace(keyword)
if k == "" {
continue
}
if strings.Contains(bodyLower, strings.ToLower(k)) {
return k
}
}
return ""
}
func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool {
if account == nil {
return false
}
if rule.DurationMinutes <= 0 {
return false
}
now := time.Now()
until := now.Add(time.Duration(rule.DurationMinutes) * time.Minute)
state := &TempUnschedState{
UntilUnix: until.Unix(),
TriggeredAtUnix: now.Unix(),
StatusCode: statusCode,
MatchedKeyword: matchedKeyword,
RuleIndex: ruleIndex,
ErrorMessage: truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes),
}
reason := ""
if raw, err := json.Marshal(state); err == nil {
reason = string(raw)
}
if reason == "" {
reason = strings.TrimSpace(state.ErrorMessage)
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
}
}
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
return true
}
func truncateTempUnschedMessage(body []byte, maxBytes int) string {
if maxBytes <= 0 || len(body) == 0 {
return ""
}
if len(body) > maxBytes {
body = body[:maxBytes]
}
return strings.TrimSpace(string(body))
}

View File

@@ -0,0 +1,22 @@
package service
import (
"context"
)
// TempUnschedState 临时不可调度状态
type TempUnschedState struct {
UntilUnix int64 `json:"until_unix"` // 解除时间Unix 时间戳)
TriggeredAtUnix int64 `json:"triggered_at_unix"` // 触发时间Unix 时间戳)
StatusCode int `json:"status_code"` // 触发的错误码
MatchedKeyword string `json:"matched_keyword"` // 匹配的关键词
RuleIndex int `json:"rule_index"` // 触发的规则索引
ErrorMessage string `json:"error_message"` // 错误消息
}
// TempUnschedCache 临时不可调度缓存接口
type TempUnschedCache interface {
SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
DeleteTempUnsched(ctx context.Context, accountID int64) error
}

View File

@@ -0,0 +1,15 @@
-- 020_add_temp_unschedulable.sql
-- 添加临时不可调度功能相关字段
-- 添加临时不可调度状态解除时间字段
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_until timestamptz;
-- 添加临时不可调度原因字段(用于排障和审计)
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS temp_unschedulable_reason text;
-- 添加索引以优化调度查询性能
CREATE INDEX IF NOT EXISTS idx_accounts_temp_unschedulable_until ON accounts(temp_unschedulable_until) WHERE deleted_at IS NULL;
-- 添加注释说明字段用途
COMMENT ON COLUMN accounts.temp_unschedulable_until IS '临时不可调度状态解除时间,当触发临时不可调度规则时设置(基于错误码或错误描述关键词)';
COMMENT ON COLUMN accounts.temp_unschedulable_reason IS '临时不可调度原因,记录触发临时不可调度的具体原因(用于排障和审计)';