Merge upstream/main

This commit is contained in:
song
2026-01-17 18:00:07 +08:00
394 changed files with 76872 additions and 1877 deletions

View File

@@ -9,16 +9,19 @@ import (
)
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
// RateMultiplier 账号计费倍率(>=0允许 0 表示该账号计费为 0
// 使用指针用于兼容旧版本调度缓存Redis中缺字段的情况nil 表示按 1.0 处理。
RateMultiplier *float64
Status string
ErrorMessage string
LastUsedAt *time.Time
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
// BillingRateMultiplier returns the billing rate multiplier for the account.
//   - nil means not configured / legacy scheduler cache missing the field; treated as 1.0
//   - zero is allowed and means the account bills at 0
//   - negative values are invalid data and fall back to 1.0 for safety
func (a *Account) BillingRateMultiplier() float64 {
	if a == nil {
		return 1.0
	}
	m := a.RateMultiplier
	if m == nil || *m < 0 {
		return 1.0
	}
	return *m
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
@@ -540,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return false
}
// WindowCostSchedulability represents the window-cost scheduling state of an account.
type WindowCostSchedulability int

const (
	// WindowCostSchedulable means the account can be scheduled normally.
	WindowCostSchedulable WindowCostSchedulability = iota
	// WindowCostStickyOnly means only sticky sessions may be scheduled.
	WindowCostStickyOnly
	// WindowCostNotSchedulable means the account must not be scheduled at all.
	WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken reports whether the account is an Anthropic
// OAuth or SetupToken account. Only these two kinds support 5h-window cost
// control and session-count control.
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
	if a.Platform != PlatformAnthropic {
		return false
	}
	switch a.Type {
	case AccountTypeOAuth, AccountTypeSetupToken:
		return true
	default:
		return false
	}
}
// GetWindowCostLimit returns the 5h-window cost threshold (USD), read from
// the extra field "window_cost_limit". 0 means the limit is disabled.
// Reading from a nil map yields the zero value, so no explicit nil check is needed.
func (a *Account) GetWindowCostLimit() float64 {
	v, ok := a.Extra["window_cost_limit"]
	if !ok {
		return 0
	}
	return parseExtraFloat64(v)
}
// GetWindowCostStickyReserve returns the cost reserve (USD) kept for sticky
// sessions, read from the extra field "window_cost_sticky_reserve". Missing
// or non-positive values fall back to the default of 10.
func (a *Account) GetWindowCostStickyReserve() float64 {
	const defaultReserve = 10.0
	v, ok := a.Extra["window_cost_sticky_reserve"]
	if !ok {
		return defaultReserve
	}
	if reserve := parseExtraFloat64(v); reserve > 0 {
		return reserve
	}
	return defaultReserve
}
// GetMaxSessions returns the maximum number of concurrent sessions, read
// from the extra field "max_sessions". 0 means the limit is disabled.
// Reading from a nil map yields the zero value, so no explicit nil check is needed.
func (a *Account) GetMaxSessions() int {
	v, ok := a.Extra["max_sessions"]
	if !ok {
		return 0
	}
	return parseExtraInt(v)
}
// GetSessionIdleTimeoutMinutes returns the session idle timeout in minutes,
// read from the extra field "session_idle_timeout_minutes". Missing or
// non-positive values fall back to the default of 5 minutes.
func (a *Account) GetSessionIdleTimeoutMinutes() int {
	const defaultTimeout = 5
	if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
		if minutes := parseExtraInt(v); minutes > 0 {
			return minutes
		}
	}
	return defaultTimeout
}
// CheckWindowCostSchedulability classifies the scheduling state from the
// current window cost:
//   - cost < limit: WindowCostSchedulable (schedule normally)
//   - limit <= cost < limit+reserve: WindowCostStickyOnly (sticky sessions only)
//   - cost >= limit+reserve: WindowCostNotSchedulable (do not schedule)
//
// A non-positive limit means the feature is disabled and the account is
// always schedulable.
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
	limit := a.GetWindowCostLimit()
	switch {
	case limit <= 0, currentWindowCost < limit:
		return WindowCostSchedulable
	case currentWindowCost < limit+a.GetWindowCostStickyReserve():
		return WindowCostStickyOnly
	default:
		return WindowCostNotSchedulable
	}
}
// parseExtraFloat64 extracts a float64 from an extra-map value, accepting
// the numeric representations JSON decoding may produce (float64, float32,
// int, int64, json.Number, or a numeric string with optional surrounding
// whitespace). Unparseable or unsupported values yield 0.
func parseExtraFloat64(value any) float64 {
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int64:
		return float64(v)
	case json.Number:
		f, err := v.Float64()
		if err != nil {
			return 0
		}
		return f
	case string:
		f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
		if err != nil {
			return 0
		}
		return f
	default:
		return 0
	}
}
// parseExtraInt extracts an int from an extra-map value, accepting the
// numeric representations JSON decoding may produce (int, int64, float64,
// json.Number, or an integer string with optional surrounding whitespace).
// Floats are truncated; unparseable or unsupported values yield 0.
func parseExtraInt(value any) int {
	switch v := value.(type) {
	case int:
		return v
	case int64:
		return int(v)
	case float64:
		return int(v)
	case json.Number:
		i, err := v.Int64()
		if err != nil {
			return 0
		}
		return int(i)
	case string:
		i, err := strconv.Atoi(strings.TrimSpace(v))
		if err != nil {
			return 0
		}
		return i
	default:
		return 0
	}
}

View File

@@ -0,0 +1,27 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil verifies that an
// account decoded from JSON without a rate_multiplier field falls back to 1.0.
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
	payload := []byte(`{"id":1,"name":"acc","status":"active"}`)
	var acc Account
	require.NoError(t, json.Unmarshal(payload, &acc))
	require.Nil(t, acc.RateMultiplier)
	require.Equal(t, 1.0, acc.BillingRateMultiplier())
}
// TestAccount_BillingRateMultiplier_AllowsZero verifies that an explicit
// zero multiplier is honored rather than replaced with the default.
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
	zero := 0.0
	acc := Account{RateMultiplier: &zero}
	require.Equal(t, 0.0, acc.BillingRateMultiplier())
}
// TestAccount_BillingRateMultiplier_NegativeFallsBackToOne verifies that an
// invalid negative multiplier is treated as 1.0.
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
	negative := -1.0
	acc := Account{RateMultiplier: &negative}
	require.Equal(t, 1.0, acc.BillingRateMultiplier())
}

View File

@@ -51,11 +51,13 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
ClearModelRateLimits(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
@@ -64,14 +66,15 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
Name *string
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
}
// CreateAccountRequest 创建账号请求

View File

@@ -147,6 +147,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
// SetModelRateLimit is a test stub; the scenarios under test must never
// invoke it, so it panics to flag an unexpected call.
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
	panic("unexpected SetModelRateLimit call")
}
// SetOverloaded is a test stub; the scenarios under test must never invoke
// it, so it panics to flag an unexpected call.
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
	panic("unexpected SetOverloaded call")
}
@@ -167,6 +171,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
panic("unexpected ClearAntigravityQuotaScopes call")
}
// ClearModelRateLimits is a test stub; the scenarios under test must never
// invoke it, so it panics to flag an unexpected call.
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
	panic("unexpected ClearModelRateLimits call")
}
// UpdateSessionWindow is a test stub; the scenarios under test must never
// invoke it, so it panics to flag an unexpected call.
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
	panic("unexpected UpdateSessionWindow call")
}

View File

@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
//
// cost: 账号口径费用total_cost * account_rate_multiplier
// standard_cost: 标准费用total_cost不含倍率
// user_cost: 用户/API Key 口径费用actual_cost受分组倍率影响
type WindowStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
}
// UsageProgress 使用量进度
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}
windowStats = &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}
// 缓存窗口统计1 分钟)
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}, nil
}
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
// GetAccountWindowStats returns the account's usage statistics accumulated
// since startTime, used by the account list page to display the current
// window cost. It delegates directly to the usage log repository.
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
	return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
}

View File

@@ -55,7 +55,8 @@ type AdminService interface {
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
@@ -106,6 +107,9 @@ type CreateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
}
type UpdateGroupInput struct {
@@ -125,6 +129,9 @@ type UpdateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
}
type CreateAccountInput struct {
@@ -137,6 +144,7 @@ type CreateAccountInput struct {
ProxyID *int64
Concurrency int
Priority int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
@@ -152,8 +160,9 @@ type UpdateAccountInput struct {
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
Status string
GroupIDs *[]int64
ExpiresAt *int64
@@ -163,16 +172,17 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
Status string
Schedulable *bool
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
AccountIDs []int64
Name string
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64 // 账号计费倍率(>=0允许 0
Status string
Schedulable *bool
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -187,9 +197,11 @@ type BulkUpdateAccountResult struct {
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
Success int `json:"success"`
Failed int `json:"failed"`
SuccessIDs []int64 `json:"success_ids"`
FailedIDs []int64 `json:"failed_ids"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct {
@@ -219,23 +231,35 @@ type GenerateRedeemCodesInput struct {
ValidityDays int // 订阅类型专用:有效天数
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
type ProxyBatchDeleteResult struct {
DeletedIDs []int64 `json:"deleted_ids"`
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
}
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyBatchDeleteSkipped struct {
ID int64 `json:"id"`
Reason string `json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
IP string
City string
Region string
Country string
CountryCode string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
@@ -245,14 +269,16 @@ type ProxyExitInfoProber interface {
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewAdminService creates a new AdminService
@@ -265,16 +291,20 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -324,6 +354,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
if input.Email != "" {
user.Email = input.Email
@@ -356,6 +388,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
@@ -394,6 +431,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
return err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
}
return nil
}
@@ -421,6 +461,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
balanceDiff := user.Balance - oldBalance
if s.authCacheInvalidator != nil && balanceDiff != 0 {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
go func() {
@@ -432,7 +476,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
@@ -545,6 +588,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -577,18 +621,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
visited := map[int64]struct{}{}
nextID := fallbackGroupID
for {
if _, seen := visited[nextID]; seen {
return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
return nil
// 降级分组不能启用 claude_code_only否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
@@ -658,13 +717,32 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
// 模型路由配置
if input.ModelRouting != nil {
group.ModelRouting = input.ModelRouting
}
if input.ModelRoutingEnabled != nil {
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
var groupKeys []string
if s.authCacheInvalidator != nil {
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
if err == nil {
groupKeys = keys
}
}
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
@@ -683,6 +761,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
}()
}
if s.authCacheInvalidator != nil {
for _, key := range groupKeys {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
}
}
return nil
}
@@ -769,6 +852,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else {
account.AutoPauseOnExpired = true
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -821,6 +910,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
account.Status = input.Status
}
@@ -871,7 +966,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
@@ -892,6 +989,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -909,6 +1012,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.Status != "" {
repoUpdates.Status = &input.Status
}
@@ -935,6 +1041,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -944,6 +1051,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -953,6 +1061,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -960,6 +1069,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = true
result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry)
}
@@ -1019,6 +1129,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if err != nil {
return nil, 0, err
}
s.attachProxyLatency(ctx, proxies)
return proxies, result.Total, nil
}
@@ -1027,7 +1138,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
if err != nil {
return nil, err
}
s.attachProxyLatency(ctx, proxies)
return proxies, nil
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
@@ -1047,6 +1163,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
// Probe latency asynchronously so creation isn't blocked by network timeout.
go s.probeProxyLatency(context.Background(), proxy)
return proxy, nil
}
@@ -1085,12 +1203,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
}
// DeleteProxy removes a proxy. A proxy still referenced by any account is
// protected: the call fails with ErrProxyInUse instead of deleting it.
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
	inUse, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
	switch {
	case err != nil:
		return err
	case inUse > 0:
		return ErrProxyInUse
	default:
		return s.proxyRepo.Delete(ctx, id)
	}
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []Account{}, 0, nil
// BatchDeleteProxies deletes the given proxies, skipping (with a reason)
// any proxy that is still referenced by accounts or whose count/delete
// operation fails. It never aborts the batch on a per-proxy failure.
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
	result := &ProxyBatchDeleteResult{}
	if len(ids) == 0 {
		return result, nil
	}
	// skip records one proxy that could not be deleted, with the reason.
	skip := func(id int64, reason string) {
		result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ID: id, Reason: reason})
	}
	for _, id := range ids {
		count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
		switch {
		case err != nil:
			skip(id, err.Error())
		case count > 0:
			skip(id, ErrProxyInUse.Error())
		default:
			if err := s.proxyRepo.Delete(ctx, id); err != nil {
				skip(id, err.Error())
				continue
			}
			result.DeletedIDs = append(result.DeletedIDs, id)
		}
	}
	return result, nil
}
// GetProxyAccounts lists summary information for the accounts bound to the
// given proxy by delegating to the proxy repository.
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
	return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -1190,23 +1349,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil {
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: false,
Message: err.Error(),
UpdatedAt: time.Now(),
})
return &ProxyTestResult{
Success: false,
Message: err.Error(),
}, nil
}
latency := latencyMs
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: true,
LatencyMs: &latency,
Message: "Proxy is accessible",
IPAddress: exitInfo.IP,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
Region: exitInfo.Region,
City: exitInfo.City,
UpdatedAt: time.Now(),
})
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: exitInfo.IP,
City: exitInfo.City,
Region: exitInfo.Region,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
}, nil
}
// probeProxyLatency probes the proxy's connectivity and exit info and stores
// the outcome (success or failure) in the latency cache. It is a no-op when
// no prober is configured or the proxy is nil.
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
	if s.proxyProber == nil || proxy == nil {
		return
	}
	exitInfo, latencyMs, probeErr := s.proxyProber.ProbeProxy(ctx, proxy.URL())
	if probeErr != nil {
		s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
			Success:   false,
			Message:   probeErr.Error(),
			UpdatedAt: time.Now(),
		})
		return
	}
	// Copy the latency so the cached entry owns its own value.
	ms := latencyMs
	s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
		Success:     true,
		LatencyMs:   &ms,
		Message:     "Proxy is accessible",
		IPAddress:   exitInfo.IP,
		Country:     exitInfo.Country,
		CountryCode: exitInfo.CountryCode,
		Region:      exitInfo.Region,
		City:        exitInfo.City,
		UpdatedAt:   time.Now(),
	})
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道Antigravity + Anthropic
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -1256,6 +1461,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
// attachProxyLatency enriches the proxy list in place with cached latency
// probe results. Proxies without a cache entry are left untouched; a cache
// read error is logged and the whole enrichment is skipped.
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
	if s.proxyLatencyCache == nil || len(proxies) == 0 {
		return
	}
	ids := make([]int64, len(proxies))
	for i := range proxies {
		ids[i] = proxies[i].ID
	}
	latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
	if err != nil {
		log.Printf("Warning: load proxy latency cache failed: %v", err)
		return
	}
	for i := range proxies {
		p := &proxies[i]
		info := latencies[p.ID]
		if info == nil {
			continue
		}
		if info.Success {
			p.LatencyStatus = "success"
			p.LatencyMs = info.LatencyMs
		} else {
			p.LatencyStatus = "failed"
		}
		p.LatencyMessage = info.Message
		p.IPAddress = info.IPAddress
		p.Country = info.Country
		p.CountryCode = info.CountryCode
		p.Region = info.Region
		p.City = info.City
	}
}
// saveProxyLatency persists a latency probe result to the cache. It is a
// no-op without a cache or info; a store failure is logged, never returned.
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
	if s.proxyLatencyCache == nil || info == nil {
		return
	}
	err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info)
	if err != nil {
		log.Printf("Warning: store proxy latency cache failed: %v", err)
	}
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {

View File

@@ -0,0 +1,80 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// accountRepoStubForBulkUpdate extends accountRepoStub with controllable
// behavior for BulkUpdate and BindGroups, used by the bulk-update tests.
type accountRepoStubForBulkUpdate struct {
	accountRepoStub
	bulkUpdateErr    error           // error returned by BulkUpdate when non-nil
	bulkUpdateIDs    []int64         // ids passed to the most recent BulkUpdate call
	bindGroupErrByID map[int64]error // per-account errors returned by BindGroups
}
// BulkUpdate records the ids it was called with, then either returns the
// configured error or reports every id as updated.
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
	recorded := make([]int64, len(ids))
	copy(recorded, ids)
	s.bulkUpdateIDs = recorded
	if err := s.bulkUpdateErr; err != nil {
		return 0, err
	}
	return int64(len(ids)), nil
}
// BindGroups returns the error configured for the given account id, or nil
// when no failure has been configured for it.
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
	err, configured := s.bindGroupErrByID[accountID]
	if !configured {
		return nil
	}
	return err
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs verifies that a fully
// successful bulk update reports every id in success_ids and none in failed_ids.
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{}
	svc := &adminServiceImpl{accountRepo: repo}

	schedulable := true
	result, err := svc.BulkUpdateAccounts(context.Background(), &BulkUpdateAccountsInput{
		AccountIDs:  []int64{1, 2, 3},
		Schedulable: &schedulable,
	})
	require.NoError(t, err)
	require.Equal(t, 3, result.Success)
	require.Equal(t, 0, result.Failed)
	require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
	require.Empty(t, result.FailedIDs)
	require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs verifies that when
// group binding fails for one account, success_ids and failed_ids partition
// the requested ids accordingly.
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{
		bindGroupErrByID: map[int64]error{2: errors.New("bind failed")},
	}
	svc := &adminServiceImpl{accountRepo: repo}

	groupIDs := []int64{10}
	schedulable := false
	result, err := svc.BulkUpdateAccounts(context.Background(), &BulkUpdateAccountsInput{
		AccountIDs:            []int64{1, 2, 3},
		GroupIDs:              &groupIDs,
		Schedulable:           &schedulable,
		SkipMixedChannelCheck: true,
	})
	require.NoError(t, err)
	require.Equal(t, 2, result.Success)
	require.Equal(t, 1, result.Failed)
	require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
	require.ElementsMatch(t, []int64{2}, result.FailedIDs)
	require.Len(t, result.Results, 3)
}

View File

@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call")
}
// GetByIDLite is not expected in these tests; panicking surfaces stray calls.
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
	panic("unexpected GetByIDLite call")
}
// Update is not expected in these tests; panicking surfaces stray calls.
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
	panic("unexpected Update call")
}
@@ -149,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
}
// proxyRepoStub is a proxy-repository test double; each field configures the
// canned result of the corresponding stubbed method.
// Fix: the merged diff left duplicate `deleteErr`/`deletedIDs` field
// declarations, which does not compile; keep a single declaration of each.
type proxyRepoStub struct {
	deleteErr    error   // returned by Delete when non-nil
	countErr     error   // returned by CountAccountsByProxyID when non-nil
	accountCount int64   // count returned by CountAccountsByProxyID
	deletedIDs   []int64 // presumably IDs recorded by Delete — confirm against Delete impl
}
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
@@ -195,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
}
// CountAccountsByProxyID returns the configured error or account count so
// tests can simulate proxies that are still referenced by accounts.
// Fix: the merged diff left the removed `panic` line in front of the real
// body, making the configured count/error unreachable; drop the stale panic.
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
	if s.countErr != nil {
		return 0, s.countErr
	}
	return s.accountCount, nil
}
// ListAccountSummariesByProxyID is not expected in these tests; panicking
// surfaces stray calls.
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
	panic("unexpected ListAccountSummariesByProxyID call")
}
type redeemRepoStub struct {
@@ -405,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require.Equal(t, []int64{404}, repo.deletedIDs)
}
// TestAdminService_DeleteProxy_InUse verifies that a proxy still referenced by
// accounts is rejected with ErrProxyInUse and nothing is deleted.
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
	stub := &proxyRepoStub{accountCount: 2}
	svc := &adminServiceImpl{proxyRepo: stub}
	require.ErrorIs(t, svc.DeleteProxy(context.Background(), 77), ErrProxyInUse)
	require.Empty(t, stub.deletedIDs)
}
func TestAdminService_DeleteProxy_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &proxyRepoStub{deleteErr: deleteErr}

View File

@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil
}
// GetByIDLite mirrors GetByID: it returns the configured error when set,
// otherwise the pre-seeded group.
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
	if err := s.getErr; err != nil {
		return nil, err
	}
	return s.getByID, nil
}
// Delete is not expected in these tests; panicking surfaces stray calls.
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
	panic("unexpected Delete call")
}
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
// TestAdminService_ValidateFallbackGroup_DetectsCycle verifies that two groups
// whose fallback IDs point at each other are rejected as a cycle.
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
	primary, fallback := int64(1), int64(2)
	stub := &groupRepoStubForFallbackCycle{
		groups: map[int64]*Group{
			primary:  {ID: primary, FallbackGroupID: &fallback},
			fallback: {ID: fallback, FallbackGroupID: &primary},
		},
	}
	svc := &adminServiceImpl{groupRepo: stub}
	err := svc.validateFallbackGroup(context.Background(), primary, fallback)
	require.Error(t, err)
	require.Contains(t, err.Error(), "fallback group cycle")
}
// groupRepoStubForFallbackCycle serves pre-seeded groups by ID so fallback
// cycle detection can be exercised without a database.
type groupRepoStubForFallbackCycle struct {
	groups map[int64]*Group // seed data returned by GetByID/GetByIDLite
}
// Create is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
	panic("unexpected Create call")
}
// Update is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
	panic("unexpected Update call")
}
// GetByID delegates to GetByIDLite; the stub does not distinguish the two.
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
	return s.GetByIDLite(ctx, id)
}
// GetByIDLite returns the seeded group for id, or ErrGroupNotFound when the
// stub has no entry for it.
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
	group, ok := s.groups[id]
	if !ok {
		return nil, ErrGroupNotFound
	}
	return group, nil
}
// Delete is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
	panic("unexpected Delete call")
}
// DeleteCascade is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
	panic("unexpected DeleteCascade call")
}
// List is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}
// ListWithFilters is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
	panic("unexpected ListWithFilters call")
}
// ListActive is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
	panic("unexpected ListActive call")
}
// ListActiveByPlatform is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
	panic("unexpected ListActiveByPlatform call")
}
// ExistsByName is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
	panic("unexpected ExistsByName call")
}
// GetAccountCount is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
	panic("unexpected GetAccountCount call")
}
// DeleteAccountGroupsByGroupID is not expected in the fallback-cycle tests.
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
	panic("unexpected DeleteAccountGroupsByGroupID call")
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// balanceUserRepoStub wraps userRepoStub and records every Update call so
// balance tests can assert on the persisted users.
type balanceUserRepoStub struct {
	*userRepoStub
	updateErr error   // returned by Update when non-nil
	updated   []*User // shallow copies of users passed to Update, in call order
}
// Update returns the configured error if set; otherwise it records a shallow
// copy of user and mirrors it into the embedded stub so later reads see it.
// A nil user is accepted and ignored.
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
	if s.updateErr != nil {
		return s.updateErr
	}
	if user == nil {
		return nil
	}
	saved := *user
	s.updated = append(s.updated, &saved)
	if base := s.userRepoStub; base != nil {
		base.user = &saved
	}
	return nil
}
// balanceRedeemRepoStub wraps redeemRepoStub and records created redeem codes.
type balanceRedeemRepoStub struct {
	*redeemRepoStub
	created []*RedeemCode // shallow copies of codes passed to Create, in call order
}
// Create records a shallow copy of code for later assertions; a nil code is
// accepted and ignored.
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
	if code == nil {
		return nil
	}
	saved := *code
	s.created = append(s.created, &saved)
	return nil
}
// authCacheInvalidatorStub records which auth-cache invalidations were requested.
type authCacheInvalidatorStub struct {
	userIDs  []int64  // arguments passed to InvalidateAuthCacheByUserID
	groupIDs []int64  // arguments passed to InvalidateAuthCacheByGroupID
	keys     []string // arguments passed to InvalidateAuthCacheByKey
}
// InvalidateAuthCacheByKey records the invalidated cache key.
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
	s.keys = append(s.keys, key)
}
// InvalidateAuthCacheByUserID records the invalidated user ID.
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
	s.userIDs = append(s.userIDs, userID)
}
// InvalidateAuthCacheByGroupID records the invalidated group ID.
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
	s.groupIDs = append(s.groupIDs, groupID)
}
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
require.NoError(t, err)
require.Equal(t, []int64{7}, invalidator.userIDs)
require.Len(t, redeemRepo.created, 1)
}
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
require.NoError(t, err)
require.Empty(t, invalidator.userIDs)
require.Empty(t, redeemRepo.created)
}

View File

@@ -12,6 +12,7 @@ import (
mathrand "math/rand"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
@@ -28,6 +29,8 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
// antigravityRetryLoopParams 重试循环的参数
type antigravityRetryLoopParams struct {
ctx context.Context
@@ -38,7 +41,9 @@ type antigravityRetryLoopParams struct {
action string
body []byte
quotaScope AntigravityQuotaScope
c *gin.Context
httpUpstream HTTPUpstream
settingService *SettingService
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
}
@@ -56,6 +61,17 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
var resp *http.Response
var usedBaseURL string
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
getUpstreamDetail := func(body []byte) string {
if !logBody {
return ""
}
return truncateString(string(body), maxBytes)
}
urlFallbackLoop:
for urlIdx, baseURL := range availableURLs {
@@ -73,8 +89,22 @@ urlFallbackLoop:
return nil, err
}
// Capture upstream request body for ops retry of this attempt.
if p.c != nil && len(p.body) > 0 {
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
}
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
@@ -89,6 +119,7 @@ urlFallbackLoop:
continue
}
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
setOpsUpstreamError(p.c, 0, safeErr, "")
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
}
@@ -99,13 +130,37 @@ urlFallbackLoop:
// "Resource has been exhausted" 是 URL 级别限流,切换 URL
if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", p.prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
continue urlFallbackLoop
}
// 账户/模型配额限流,重试 3 次(指数退避)
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
@@ -131,6 +186,18 @@ urlFallbackLoop:
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
Platform: p.account.Platform,
AccountID: p.account.ID,
AccountName: p.account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: upstreamMsg,
Detail: getUpstreamDetail(respBody),
})
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
@@ -679,6 +746,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL = account.Proxy.URL()
}
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
sanitizeThinkingBlocks(&claudeReq)
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts := s.getClaudeTransformOptions(ctx)
@@ -690,6 +760,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// Safety net: ensure no cache_control leaked into Gemini request
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action := "streamGenerateContent"
@@ -704,7 +777,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
action: action,
body: geminiBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
if err != nil {
@@ -720,6 +795,28 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
@@ -753,6 +850,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr != nil {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
})
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
continue
}
@@ -766,6 +871,26 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
kind := "signature_retry"
if strings.TrimSpace(stage.name) != "" {
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
}
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
retryUpstreamDetail := ""
if logBody {
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: kind,
Message: retryUpstreamMsg,
Detail: retryUpstreamDetail,
})
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
@@ -793,10 +918,31 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
}
}
@@ -879,6 +1025,143 @@ func extractAntigravityErrorMessage(body []byte) string {
return ""
}
// cleanCacheControlFromGeminiJSON strips cache_control fields from a
// serialized Gemini request as an emergency safety net; the transformation
// step should already have removed them. The original bytes are returned
// whenever parsing or re-marshaling fails, or when nothing needed cleaning.
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
		return body
	}
	if !removeCacheControlFromAny(payload) {
		return body
	}
	out, err := json.Marshal(payload)
	if err != nil {
		return body
	}
	log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
	return out
}
// removeCacheControlFromAny walks an arbitrary decoded-JSON value and deletes
// every "cache_control" key found in any nested map, mutating the value in
// place. It reports whether at least one key was removed. Values under a
// deleted "cache_control" key are not descended into.
func removeCacheControlFromAny(v any) bool {
	switch node := v.(type) {
	case map[string]any:
		removed := false
		if _, ok := node["cache_control"]; ok {
			delete(node, "cache_control")
			removed = true
		}
		for key, child := range node {
			if key == "cache_control" {
				continue
			}
			if removeCacheControlFromAny(child) {
				removed = true
			}
		}
		return removed
	case []any:
		removed := false
		for _, elem := range node {
			removed = removeCacheControlFromAny(elem) || removed
		}
		return removed
	default:
		return false
	}
}
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks.
// Thinking blocks do NOT support the cache_control field (Anthropic API/Vertex AI
// requirement), so it is stripped from every thinking block found. Additionally,
// thinking blocks in history messages (all but the last message) are converted
// to plain text blocks to avoid upstream signature validation errors.
// The request is mutated in place; a nil request is a no-op.
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
	if req == nil {
		return
	}
	log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
	// Clean system blocks. System content that is not a JSON array of block
	// objects (e.g. a plain string) fails to unmarshal and is left untouched.
	if len(req.System) > 0 {
		var systemBlocks []map[string]any
		if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
			for i := range systemBlocks {
				if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
					if removeCacheControlFromAny(systemBlocks[i]) {
						log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
					}
				}
			}
			// Marshal back
			if cleaned, err := json.Marshal(systemBlocks); err == nil {
				req.System = cleaned
			}
		}
	}
	// Clean message content blocks and flatten history thinking.
	lastMsgIdx := len(req.Messages) - 1
	for msgIdx := range req.Messages {
		raw := req.Messages[msgIdx].Content
		if len(raw) == 0 {
			continue
		}
		// Try to parse as blocks array; plain-string content is skipped.
		var blocks []map[string]any
		if err := json.Unmarshal(raw, &blocks); err != nil {
			continue
		}
		cleaned := false
		for blockIdx := range blocks {
			blockType, _ := blocks[blockIdx]["type"].(string)
			// Check for thinking blocks (typed or untyped)
			if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
				// 1. Clean cache_control
				if removeCacheControlFromAny(blocks[blockIdx]) {
					log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
					cleaned = true
				}
				// 2. Flatten to text if it's a history message (not the last one)
				if msgIdx < lastMsgIdx {
					log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
					// Extract thinking content
					var textContent string
					if t, ok := blocks[blockIdx]["thinking"].(string); ok {
						textContent = t
					} else {
						// Fallback for non-string content (marshal it)
						if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
							textContent = string(b)
						}
					}
					// Convert to text block, dropping thinking-specific fields.
					blocks[blockIdx]["type"] = "text"
					blocks[blockIdx]["text"] = textContent
					delete(blocks[blockIdx], "thinking")
					delete(blocks[blockIdx], "signature")
					delete(blocks[blockIdx], "cache_control") // Ensure it's gone
					cleaned = true
				}
			}
		}
		// Marshal back if modified
		if cleaned {
			if marshaled, err := json.Marshal(blocks); err == nil {
				req.Messages[msgIdx].Content = marshaled
			}
		}
	}
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
@@ -1184,7 +1467,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
action: upstreamAction,
body: wrappedBody,
quotaScope: quotaScope,
c: c,
httpUpstream: s.httpUpstream,
settingService: s.settingService,
handleError: s.handleUpstreamError,
})
if err != nil {
@@ -1234,22 +1519,62 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 解包并返回错误
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
unwrappedForOps := unwrapped
if unwrapErr != nil || len(unwrappedForOps) == 0 {
unwrappedForOps = respBody
}
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(unwrappedForOps), maxBytes)
}
// Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500))
c.Data(resp.StatusCode, contentType, unwrapped)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
c.Data(resp.StatusCode, contentType, unwrappedForOps)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
@@ -1338,9 +1663,15 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
// antigravityUseScopeRateLimit reports whether scope-level 429 rate limiting
// is enabled via the environment variable named by antigravityScopeRateLimitEnv
// (accepted truthy values: "1", "true", "yes", "on", case-insensitive).
func antigravityUseScopeRateLimit() bool {
	switch strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv))) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败:使用配置的 fallback 时间,直接限流整个账户
@@ -1350,19 +1681,30 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
}
defaultDur := time.Duration(fallbackMinutes) * time.Minute
ra := time.Now().Add(defaultDur)
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
}
resetTime := time.Unix(*resetAt, 0)
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if quotaScope == "" {
return
}
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
if useScopeLimit {
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
}
} else {
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
}
}
return
}
@@ -1533,6 +1875,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
continue
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
@@ -1824,9 +2167,36 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
return fmt.Errorf("%s", message)
}
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
// 记录上游错误详情便于调试
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, string(body))
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
maxBytes := 2048
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
}
upstreamDetail := ""
if logBody {
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
if logBody {
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
}
var statusCode int
var errType, errMsg string
@@ -1862,7 +2232,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstr
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
return fmt.Errorf("upstream error: %d", upstreamStatus)
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
@@ -2189,6 +2562,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
continue
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}

View File

@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if !a.IsSchedulable() {
return false
}
if a.isModelRateLimited(requestedModel) {
return false
}
if a.Platform != PlatformAntigravity {
return true
}

View File

@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID

View File

@@ -3,16 +3,18 @@ package service
import "time"
type APIKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
IPWhitelist []string
IPBlacklist []string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *APIKey) IsActive() bool {

View File

@@ -0,0 +1,51 @@
package service
// APIKeyAuthSnapshot is the cached authentication view of an API key.
// It carries only the fields the auth path needs (identity, status, IP
// restrictions, owning user and bound group) — no timestamps or display data.
type APIKeyAuthSnapshot struct {
	APIKeyID    int64                    `json:"api_key_id"`
	UserID      int64                    `json:"user_id"`
	GroupID     *int64                   `json:"group_id,omitempty"`
	Status      string                   `json:"status"`
	IPWhitelist []string                 `json:"ip_whitelist,omitempty"`
	IPBlacklist []string                 `json:"ip_blacklist,omitempty"`
	User        APIKeyAuthUserSnapshot   `json:"user"`
	Group       *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
}

// APIKeyAuthUserSnapshot captures the owning user's auth-relevant fields.
type APIKeyAuthUserSnapshot struct {
	ID          int64   `json:"id"`
	Status      string  `json:"status"`
	Role        string  `json:"role"`
	Balance     float64 `json:"balance"`
	Concurrency int     `json:"concurrency"`
}

// APIKeyAuthGroupSnapshot captures the bound group's auth-relevant fields
// (billing multipliers, usage limits, image pricing, routing).
type APIKeyAuthGroupSnapshot struct {
	ID               int64    `json:"id"`
	Name             string   `json:"name"`
	Platform         string   `json:"platform"`
	Status           string   `json:"status"`
	SubscriptionType string   `json:"subscription_type"`
	RateMultiplier   float64  `json:"rate_multiplier"`
	DailyLimitUSD    *float64 `json:"daily_limit_usd,omitempty"`
	WeeklyLimitUSD   *float64 `json:"weekly_limit_usd,omitempty"`
	MonthlyLimitUSD  *float64 `json:"monthly_limit_usd,omitempty"`
	ImagePrice1K     *float64 `json:"image_price_1k,omitempty"`
	ImagePrice2K     *float64 `json:"image_price_2k,omitempty"`
	ImagePrice4K     *float64 `json:"image_price_4k,omitempty"`
	ClaudeCodeOnly   bool     `json:"claude_code_only"`
	FallbackGroupID  *int64   `json:"fallback_group_id,omitempty"`
	// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
	// Only anthropic groups use these fields; others may leave them empty.
	ModelRouting        map[string][]int64 `json:"model_routing,omitempty"`
	ModelRoutingEnabled bool               `json:"model_routing_enabled"`
}

// APIKeyAuthCacheEntry is one cache slot, with negative-caching support:
// NotFound marks a key that is known to not exist; otherwise Snapshot holds
// the cached auth data.
type APIKeyAuthCacheEntry struct {
	NotFound bool                `json:"not_found"`
	Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}

View File

@@ -0,0 +1,273 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
// apiKeyAuthCacheConfig holds the tuning knobs for the API-key auth cache:
// an in-process L1 layer, a shared L2 layer, negative caching for unknown
// keys, TTL jitter, and optional singleflight collapsing of concurrent loads.
type apiKeyAuthCacheConfig struct {
	l1Size        int           // max L1 entries; 0 disables L1
	l1TTL         time.Duration // L1 entry lifetime; 0 disables L1
	l2TTL         time.Duration // L2 entry lifetime; 0 disables L2
	negativeTTL   time.Duration // lifetime of not-found entries; 0 disables negative caching
	jitterPercent int           // ± percentage applied to TTLs (clamped at 100 by jitterTTL)
	singleflight  bool          // collapse concurrent loads of the same key
}

var (
	jitterRandMu sync.Mutex
	// jitterRand is a dedicated, mutex-guarded random source for TTL jitter,
	// avoiding a global math/rand Seed.
	jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
// newAPIKeyAuthCacheConfig builds the auth-cache settings from the
// application config. A nil config yields the zero value, which disables
// every cache layer.
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
	if cfg == nil {
		return apiKeyAuthCacheConfig{}
	}
	src := cfg.APIKeyAuth
	out := apiKeyAuthCacheConfig{
		l1Size:        src.L1Size,
		jitterPercent: src.JitterPercent,
		singleflight:  src.Singleflight,
	}
	// Second-valued config fields become durations here, once.
	out.l1TTL = time.Duration(src.L1TTLSeconds) * time.Second
	out.l2TTL = time.Duration(src.L2TTLSeconds) * time.Second
	out.negativeTTL = time.Duration(src.NegativeTTLSeconds) * time.Second
	return out
}
// l1Enabled reports whether the in-process cache layer is configured
// (requires both a positive size and a positive TTL).
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
	return c.l1Size > 0 && c.l1TTL > 0
}

// l2Enabled reports whether the shared cache layer is configured.
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
	return c.l2TTL > 0
}

// negativeEnabled reports whether not-found results should be cached.
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
	return c.negativeTTL > 0
}
// jitterTTL spreads ttl by ±jitterPercent% so that entries filled in a burst
// do not all expire at the same instant. A non-positive ttl or a disabled
// jitter percentage returns ttl unchanged; percentages above 100 are clamped.
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
	if ttl <= 0 || c.jitterPercent <= 0 {
		return ttl
	}
	pct := c.jitterPercent
	if pct > 100 {
		pct = 100
	}
	spread := float64(pct) / 100
	jitterRandMu.Lock()
	r := jitterRand.Float64()
	jitterRandMu.Unlock()
	// scale is uniform in [1-spread, 1+spread].
	scale := 1 - spread + r*(2*spread)
	if scale <= 0 {
		return ttl
	}
	return time.Duration(float64(ttl) * scale)
}
// initAuthCache wires the auth-cache configuration into the service and,
// when L1 is enabled, allocates the in-process ristretto cache.
// If ristretto construction fails, the service silently runs without an L1
// layer (L2 and repository lookups still work) rather than failing startup.
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
	s.authCfg = newAPIKeyAuthCacheConfig(cfg)
	if !s.authCfg.l1Enabled() {
		return
	}
	cache, err := ristretto.NewCache(&ristretto.Config{
		// Ristretto recommends ~10x counters per expected entry for its
		// admission policy; with cost 1 per entry, MaxCost is an entry count.
		NumCounters: int64(s.authCfg.l1Size) * 10,
		MaxCost:     int64(s.authCfg.l1Size),
		BufferItems: 64,
	})
	if err != nil {
		// Degrade gracefully: no L1 cache.
		return
	}
	s.authCacheL1 = cache
}
// authCacheKey derives the cache key for an API key string: a hex-encoded
// SHA-256 digest, so the raw key never appears in cache storage.
func (s *APIKeyService) authCacheKey(key string) string {
	digest := sha256.Sum256([]byte(key))
	return hex.EncodeToString(digest[:])
}
// getAuthCacheEntry looks up an auth cache entry, trying the in-process L1
// cache first and falling back to the shared L2 cache. An L2 hit is copied
// back into L1 so subsequent lookups stay local. The second return reports
// whether a usable entry was found.
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
	if l1 := s.authCacheL1; l1 != nil {
		if raw, hit := l1.Get(cacheKey); hit {
			// A value of an unexpected type is treated as a miss.
			if entry, valid := raw.(*APIKeyAuthCacheEntry); valid {
				return entry, true
			}
		}
	}
	if s.cache == nil || !s.authCfg.l2Enabled() {
		return nil, false
	}
	entry, err := s.cache.GetAuthCache(ctx, cacheKey)
	if err != nil {
		return nil, false
	}
	// Backfill L1 so the next lookup avoids the shared-cache round trip.
	s.setAuthCacheL1(cacheKey, entry)
	return entry, true
}
// setAuthCacheL1 stores an entry in the in-process cache. Not-found entries
// use the (shorter) negative TTL when configured; the final TTL is jittered.
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
	if s.authCacheL1 == nil || entry == nil {
		return
	}
	ttl := s.authCfg.l1TTL
	if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
		ttl = s.authCfg.negativeTTL
	}
	// Cost 1 per entry; admission failures are acceptable for a cache.
	_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, s.authCfg.jitterTTL(ttl))
}
// setAuthCacheEntry writes an entry into both layers: always into L1 (which
// applies its own TTL) and into L2 when configured, using the supplied ttl
// with jitter. Write failures on the shared cache are ignored (best effort).
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
	if entry == nil {
		return
	}
	s.setAuthCacheL1(cacheKey, entry)
	if s.cache != nil && s.authCfg.l2Enabled() {
		_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
	}
}
// deleteAuthCache evicts the entry from both cache layers. Errors from the
// shared cache are intentionally ignored: invalidation is best-effort and
// stale entries age out via TTL.
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
	if l1 := s.authCacheL1; l1 != nil {
		l1.Del(cacheKey)
	}
	if s.cache != nil {
		_ = s.cache.DeleteAuthCache(ctx, cacheKey)
	}
}
// loadAuthCacheEntry fetches the API key from the repository and fills the
// cache with the result. A repository not-found becomes a negative entry
// (cached only when negative caching is enabled); any other repository error
// is returned without caching.
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
	apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
	if err != nil {
		if errors.Is(err, ErrAPIKeyNotFound) {
			entry := &APIKeyAuthCacheEntry{NotFound: true}
			if s.authCfg.negativeEnabled() {
				s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
			}
			return entry, nil
		}
		return nil, fmt.Errorf("get api key: %w", err)
	}
	// Ensure the raw key is present on the loaded record before snapshotting.
	apiKey.Key = key
	snapshot := s.snapshotFromAPIKey(apiKey)
	if snapshot == nil {
		// snapshotFromAPIKey returns nil when the record has no user
		// association; treat such a record as unusable for auth.
		return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
	}
	entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
	s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
	return entry, nil
}
// applyAuthCacheEntry converts a cache entry into an API key result. The
// second return reports whether the entry was usable; false tells the caller
// to fall back to a repository load.
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
	switch {
	case entry == nil:
		return nil, false, nil
	case entry.NotFound:
		// Negative entry: the key is known to not exist.
		return nil, true, ErrAPIKeyNotFound
	case entry.Snapshot == nil:
		return nil, false, nil
	default:
		return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
	}
}
// snapshotFromAPIKey flattens a loaded APIKey (with its User and optional
// Group) into the cacheable auth snapshot. Returns nil when the record or
// its user association is missing, since such a record cannot authenticate.
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
	if apiKey == nil || apiKey.User == nil {
		return nil
	}
	snapshot := &APIKeyAuthSnapshot{
		APIKeyID:    apiKey.ID,
		UserID:      apiKey.UserID,
		GroupID:     apiKey.GroupID,
		Status:      apiKey.Status,
		IPWhitelist: apiKey.IPWhitelist,
		IPBlacklist: apiKey.IPBlacklist,
		User: APIKeyAuthUserSnapshot{
			ID:          apiKey.User.ID,
			Status:      apiKey.User.Status,
			Role:        apiKey.User.Role,
			Balance:     apiKey.User.Balance,
			Concurrency: apiKey.User.Concurrency,
		},
	}
	// The group is optional; copy it only when present.
	if apiKey.Group != nil {
		snapshot.Group = &APIKeyAuthGroupSnapshot{
			ID:                  apiKey.Group.ID,
			Name:                apiKey.Group.Name,
			Platform:            apiKey.Group.Platform,
			Status:              apiKey.Group.Status,
			SubscriptionType:    apiKey.Group.SubscriptionType,
			RateMultiplier:      apiKey.Group.RateMultiplier,
			DailyLimitUSD:       apiKey.Group.DailyLimitUSD,
			WeeklyLimitUSD:      apiKey.Group.WeeklyLimitUSD,
			MonthlyLimitUSD:     apiKey.Group.MonthlyLimitUSD,
			ImagePrice1K:        apiKey.Group.ImagePrice1K,
			ImagePrice2K:        apiKey.Group.ImagePrice2K,
			ImagePrice4K:        apiKey.Group.ImagePrice4K,
			ClaudeCodeOnly:      apiKey.Group.ClaudeCodeOnly,
			FallbackGroupID:     apiKey.Group.FallbackGroupID,
			ModelRouting:        apiKey.Group.ModelRouting,
			ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
		}
	}
	return snapshot
}
// snapshotToAPIKey rebuilds an APIKey (with User and optional Group) from a
// cached snapshot. The raw key is supplied by the caller, since the snapshot
// never stores it. Returns nil for a nil snapshot.
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
	if snapshot == nil {
		return nil
	}
	apiKey := &APIKey{
		ID:          snapshot.APIKeyID,
		UserID:      snapshot.UserID,
		GroupID:     snapshot.GroupID,
		Key:         key,
		Status:      snapshot.Status,
		IPWhitelist: snapshot.IPWhitelist,
		IPBlacklist: snapshot.IPBlacklist,
		User: &User{
			ID:          snapshot.User.ID,
			Status:      snapshot.User.Status,
			Role:        snapshot.User.Role,
			Balance:     snapshot.User.Balance,
			Concurrency: snapshot.User.Concurrency,
		},
	}
	if snapshot.Group != nil {
		apiKey.Group = &Group{
			ID:       snapshot.Group.ID,
			Name:     snapshot.Group.Name,
			Platform: snapshot.Group.Platform,
			Status:   snapshot.Group.Status,
			// NOTE(review): Hydrated presumably marks the group as fully
			// populated (vs. lazily loaded) — confirm against Group's docs.
			Hydrated:            true,
			SubscriptionType:    snapshot.Group.SubscriptionType,
			RateMultiplier:      snapshot.Group.RateMultiplier,
			DailyLimitUSD:       snapshot.Group.DailyLimitUSD,
			WeeklyLimitUSD:      snapshot.Group.WeeklyLimitUSD,
			MonthlyLimitUSD:     snapshot.Group.MonthlyLimitUSD,
			ImagePrice1K:        snapshot.Group.ImagePrice1K,
			ImagePrice2K:        snapshot.Group.ImagePrice2K,
			ImagePrice4K:        snapshot.Group.ImagePrice4K,
			ClaudeCodeOnly:      snapshot.Group.ClaudeCodeOnly,
			FallbackGroupID:     snapshot.Group.FallbackGroupID,
			ModelRouting:        snapshot.Group.ModelRouting,
			ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
		}
	}
	return apiKey
}

View File

@@ -0,0 +1,48 @@
package service
import "context"
// InvalidateAuthCacheByKey evicts the auth cache entry for a single API key.
// An empty key is a no-op.
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
	if key == "" {
		return
	}
	s.deleteAuthCache(ctx, s.authCacheKey(key))
}
// InvalidateAuthCacheByUserID evicts the auth cache entries of every API key
// owned by the user. Invalidation is best-effort: if listing the user's keys
// fails, the method returns silently and stale entries age out via TTL.
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
	if userID <= 0 {
		return
	}
	keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
	if err != nil {
		// Deliberate swallow: best-effort invalidation.
		return
	}
	s.deleteAuthCacheByKeys(ctx, keys)
}
// InvalidateAuthCacheByGroupID evicts the auth cache entries of every API key
// bound to the group. Best-effort, like InvalidateAuthCacheByUserID: a
// listing failure returns silently and stale entries expire via TTL.
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
	if groupID <= 0 {
		return
	}
	keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
	if err != nil {
		// Deliberate swallow: best-effort invalidation.
		return
	}
	s.deleteAuthCacheByKeys(ctx, keys)
}
// deleteAuthCacheByKeys evicts the cache entry of every non-empty key in the
// slice; a nil or empty slice is a no-op.
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
	for _, k := range keys {
		if k != "" {
			s.deleteAuthCache(ctx, s.authCacheKey(k))
		}
	}
}

View File

@@ -9,8 +9,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var (
@@ -20,6 +23,7 @@ var (
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
)
const (
@@ -29,9 +33,11 @@ const (
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID用于删除等轻量场景
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
@@ -43,6 +49,8 @@ type APIKeyRepository interface {
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
}
// APIKeyCache defines cache operations for API key service
@@ -53,20 +61,35 @@ type APIKeyCache interface {
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type APIKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
}
// APIKeyService API Key服务
@@ -77,6 +100,9 @@ type APIKeyService struct {
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -88,7 +114,7 @@ func NewAPIKeyService(
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -96,6 +122,8 @@ func NewAPIKeyService(
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
}
// GenerateKey 生成随机API Key
@@ -186,6 +214,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("get user: %w", err)
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
@@ -236,17 +278,21 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
@@ -282,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
cacheKey := s.authCacheKey(key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
if s.authCfg.singleflight {
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
return s.loadAuthCacheEntry(ctx, key, cacheKey)
})
if err != nil {
return nil, err
}
entry, _ := value.(*APIKeyAuthCacheEntry)
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
} else {
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
if err != nil {
return nil, err
}
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
apiKey.Key = key
return apiKey, nil
}
@@ -312,6 +386,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, ErrInsufficientPerms
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
@@ -344,19 +432,22 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
@@ -366,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
// 清除Redis缓存使用 userID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
}
s.InvalidateAuthCacheByKey(ctx, key)
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)

View File

@@ -0,0 +1,423 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
// authRepoStub is an APIKeyRepository test double for the auth-cache tests.
// Only the three function fields below are expected to be exercised; every
// other method panics so an unexpected repository call fails the test loudly.
type authRepoStub struct {
	getByKeyForAuth   func(ctx context.Context, key string) (*APIKey, error)
	listKeysByUserID  func(ctx context.Context, userID int64) ([]string, error)
	listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}

func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
	panic("unexpected Create call")
}

func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
	panic("unexpected GetByID call")
}

func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
	panic("unexpected GetKeyAndOwnerID call")
}

func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
	panic("unexpected GetByKey call")
}

// GetByKeyForAuth delegates to the configured hook; tests that do not expect
// a repository hit leave it nil so the panic pinpoints the stray call.
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
	if s.getByKeyForAuth == nil {
		panic("unexpected GetByKeyForAuth call")
	}
	return s.getByKeyForAuth(ctx, key)
}

func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
	panic("unexpected Update call")
}

func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
	panic("unexpected Delete call")
}

func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected ListByUserID call")
}

func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
	panic("unexpected VerifyOwnership call")
}

func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	panic("unexpected CountByUserID call")
}

func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
	panic("unexpected ExistsByKey call")
}

func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected ListByGroupID call")
}

func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
	panic("unexpected SearchAPIKeys call")
}

func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected ClearGroupIDByGroupID call")
}

func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected CountByGroupID call")
}

// ListKeysByUserID delegates to the configured hook (see getByKeyForAuth).
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
	if s.listKeysByUserID == nil {
		panic("unexpected ListKeysByUserID call")
	}
	return s.listKeysByUserID(ctx, userID)
}

// ListKeysByGroupID delegates to the configured hook (see getByKeyForAuth).
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
	if s.listKeysByGroupID == nil {
		panic("unexpected ListKeysByGroupID call")
	}
	return s.listKeysByGroupID(ctx, groupID)
}
// authCacheStub is an APIKeyCache test double. Rate-limit and usage methods
// are no-ops; auth-cache reads delegate to getAuthCache (defaulting to
// redis.Nil, i.e. a cache miss) and writes/deletes record their keys so
// tests can assert on cache traffic.
type authCacheStub struct {
	getAuthCache   func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
	setAuthKeys    []string // keys passed to SetAuthCache
	deleteAuthKeys []string // keys passed to DeleteAuthCache
}

func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}

func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
	return nil
}

func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
	return nil
}

func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
	return nil
}

func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
	return nil
}

// GetAuthCache returns a miss (redis.Nil) unless a hook is configured.
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
	if s.getAuthCache == nil {
		return nil, redis.Nil
	}
	return s.getAuthCache(ctx, key)
}

func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
	s.setAuthKeys = append(s.setAuthKeys, key)
	return nil
}

func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
	s.deleteAuthKeys = append(s.deleteAuthKeys, key)
	return nil
}
// TestAPIKeyService_GetByKey_UsesL2Cache verifies that a warm L2 entry is
// served without any repository call, and that the snapshot round-trips all
// group routing fields back onto the returned APIKey.
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			// A repository hit here would mean the L2 cache was bypassed.
			return nil, errors.New("unexpected repo call")
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	groupID := int64(9)
	cacheEntry := &APIKeyAuthCacheEntry{
		Snapshot: &APIKeyAuthSnapshot{
			APIKeyID: 1,
			UserID:   2,
			GroupID:  &groupID,
			Status:   StatusActive,
			User: APIKeyAuthUserSnapshot{
				ID:          2,
				Status:      StatusActive,
				Role:        RoleUser,
				Balance:     10,
				Concurrency: 3,
			},
			Group: &APIKeyAuthGroupSnapshot{
				ID:                  groupID,
				Name:                "g",
				Platform:            PlatformAnthropic,
				Status:              StatusActive,
				SubscriptionType:    SubscriptionTypeStandard,
				RateMultiplier:      1,
				ModelRoutingEnabled: true,
				ModelRouting: map[string][]int64{
					"claude-opus-*": {1, 2},
				},
			},
		},
	}
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return cacheEntry, nil
	}

	apiKey, err := svc.GetByKey(context.Background(), "k1")
	require.NoError(t, err)
	require.Equal(t, int64(1), apiKey.ID)
	require.Equal(t, int64(2), apiKey.User.ID)
	require.Equal(t, groupID, apiKey.Group.ID)
	require.True(t, apiKey.Group.ModelRoutingEnabled)
	require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
}

// TestAPIKeyService_GetByKey_NegativeCache verifies that a cached not-found
// entry surfaces ErrAPIKeyNotFound without touching the repository.
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return nil, errors.New("unexpected repo call")
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return &APIKeyAuthCacheEntry{NotFound: true}, nil
	}
	_, err := svc.GetByKey(context.Background(), "missing")
	require.ErrorIs(t, err, ErrAPIKeyNotFound)
}

// TestAPIKeyService_GetByKey_CacheMissStoresL2 verifies that a cache miss
// falls through to the repository and writes the loaded entry into L2 once.
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return &APIKey{
				ID:     5,
				UserID: 7,
				Status: StatusActive,
				User: &User{
					ID:          7,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     12,
					Concurrency: 2,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return nil, redis.Nil
	}
	apiKey, err := svc.GetByKey(context.Background(), "k2")
	require.NoError(t, err)
	require.Equal(t, int64(5), apiKey.ID)
	require.Len(t, cache.setAuthKeys, 1)
}
// TestAPIKeyService_GetByKey_UsesL1Cache verifies that with only L1 enabled
// the first lookup hits the repository and the second is served from the
// in-process cache (exactly one repository call in total).
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
	var calls int32
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			atomic.AddInt32(&calls, 1)
			return &APIKey{
				ID:     21,
				UserID: 3,
				Status: StatusActive,
				User: &User{
					ID:          3,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     5,
					Concurrency: 2,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L1Size:       1000,
			L1TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	require.NotNil(t, svc.authCacheL1)

	_, err := svc.GetByKey(context.Background(), "k-l1")
	require.NoError(t, err)
	// Ristretto applies Sets asynchronously; Wait flushes the buffers so the
	// entry is visible before the second lookup.
	svc.authCacheL1.Wait()
	cacheKey := svc.authCacheKey("k-l1")
	_, ok := svc.authCacheL1.Get(cacheKey)
	require.True(t, ok)

	_, err = svc.GetByKey(context.Background(), "k-l1")
	require.NoError(t, err)
	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

// TestAPIKeyService_InvalidateAuthCacheByUserID verifies that every key the
// repository reports for the user gets its cache entry deleted.
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
			return []string{"k1", "k2"}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	svc.InvalidateAuthCacheByUserID(context.Background(), 7)
	require.Len(t, cache.deleteAuthKeys, 2)
}

// TestAPIKeyService_InvalidateAuthCacheByGroupID verifies that every key the
// repository reports for the group gets its cache entry deleted.
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
			return []string{"k1", "k2"}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
	require.Len(t, cache.deleteAuthKeys, 2)
}
// TestAPIKeyService_InvalidateAuthCacheByKey verifies that invalidating a
// single key deletes exactly one auth-cache entry.
// Fix: the previous version configured a listKeysByUserID hook that
// InvalidateAuthCacheByKey never calls — a misleading dead fixture; the stub
// is now left empty so any repository call panics and fails the test.
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	svc.InvalidateAuthCacheByKey(context.Background(), "k1")
	require.Len(t, cache.deleteAuthKeys, 1)
}
// TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss verifies that a
// repository not-found both surfaces ErrAPIKeyNotFound and writes a negative
// entry into the cache (one SetAuthCache call).
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return nil, ErrAPIKeyNotFound
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return nil, redis.Nil
	}
	_, err := svc.GetByKey(context.Background(), "missing")
	require.ErrorIs(t, err, ErrAPIKeyNotFound)
	require.Len(t, cache.setAuthKeys, 1)
}

// TestAPIKeyService_GetByKey_SingleflightCollapses verifies that with
// singleflight enabled, five concurrent lookups of the same key trigger
// exactly one repository load and all succeed.
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
	var calls int32
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			atomic.AddInt32(&calls, 1)
			// Delay keeps the in-flight load open long enough for the other
			// goroutines to join it.
			time.Sleep(50 * time.Millisecond)
			return &APIKey{
				ID:     11,
				UserID: 2,
				Status: StatusActive,
				User: &User{
					ID:          2,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     1,
					Concurrency: 1,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			Singleflight: true,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	start := make(chan struct{})
	wg := sync.WaitGroup{}
	errs := make([]error, 5)
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			<-start
			_, err := svc.GetByKey(context.Background(), "k1")
			errs[idx] = err
		}(i)
	}
	close(start)
	wg.Wait()
	for _, err := range errs {
		require.NoError(t, err)
	}
	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

View File

@@ -20,13 +20,12 @@ import (
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
ownerID int64 // GetOwnerID 的返回值
ownerErr error // GetOwnerID 的错误返回值
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
}
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
if s.getByIDErr != nil {
return nil, s.getByIDErr
}
if s.apiKey != nil {
clone := *s.apiKey
return &clone, nil
}
panic("unexpected GetByID call")
}
// GetOwnerID 返回预设的所有者 ID 或错误。
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return s.ownerID, s.ownerErr
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
if s.getByIDErr != nil {
return "", 0, s.getByIDErr
}
if s.apiKey != nil {
return s.apiKey.Key, s.apiKey.UserID, nil
}
return "", 0, ErrAPIKeyNotFound
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// ListKeysByUserID is not expected in the Delete tests.
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
	panic("unexpected ListKeysByUserID call")
}
// ListKeysByGroupID is not expected in the Delete tests.
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
	panic("unexpected ListKeysByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
// - invalidated: 记录被清除缓存的用户 ID 列表
type apiKeyCacheStub struct {
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0表示用户未超过创建次数限制
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
// GetAuthCache always misses so the service never short-circuits on cached auth.
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
	return nil, nil
}
// SetAuthCache is a no-op in tests.
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
	return nil
}
// DeleteAuthCache records the cache keys whose auth entries were invalidated,
// so tests can assert the auth cache was cleared for the deleted key.
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
	s.deleteAuthKeys = append(s.deleteAuthKeys, key)
	return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.ErrorIs(t, err, ErrInsufficientPerms)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
require.Empty(t, cache.invalidated) // 验证缓存未被清除
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - GetKeyAndOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - GetKeyAndOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
require.ErrorContains(t, err, "delete api key")
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}

View File

@@ -0,0 +1,33 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// TestUsageService_InvalidateUsageCaches verifies that the auth cache is only
// invalidated when the third argument is true (presumably "balance changed" —
// confirm against UsageService.invalidateUsageCaches).
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
	invalidator := &authCacheInvalidatorStub{}
	svc := &UsageService{authCacheInvalidator: invalidator}
	svc.invalidateUsageCaches(context.Background(), 7, false)
	require.Empty(t, invalidator.userIDs)
	svc.invalidateUsageCaches(context.Background(), 7, true)
	require.Equal(t, []int64{7}, invalidator.userIDs)
}
// TestRedeemService_InvalidateRedeemCaches_AuthCache verifies that the user's
// auth cache is invalidated exactly once per redemption for every redeem code
// type (balance, concurrency, subscription).
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
	invalidator := &authCacheInvalidatorStub{}
	svc := &RedeemService{authCacheInvalidator: invalidator}
	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
	groupID := int64(3)
	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
	require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
}

View File

@@ -52,6 +52,7 @@ type AuthService struct {
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
}
// NewAuthService 创建认证服务实例
@@ -62,6 +63,7 @@ func NewAuthService(
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
) *AuthService {
return &AuthService{
userRepo: userRepo,
@@ -70,16 +72,17 @@ func NewAuthService(
emailService: emailService,
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
}
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
return s.RegisterWithVerification(ctx, email, password, "", "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// RegisterWithVerification 用户注册(支持邮件验证和优惠码返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
} else {
// 重新获取用户信息以获取更新后的余额
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
user = updatedUser
}
}
}
// 生成token
token, err := s.GenerateToken(user)
if err != nil {
@@ -341,7 +357,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth例如 OpenAI/Gemini
// 注意:该函数用于 LinuxDo OAuth 登录场景(不同于上游账号的 OAuth例如 Claude/OpenAI/Gemini
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
email = strings.TrimSpace(email)
@@ -360,8 +376,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
// OAuth 首次登录视为注册fail-closesettingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}

View File

@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
emailService,
nil,
nil,
nil, // promoService
)
}
@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}

View File

@@ -0,0 +1,208 @@
package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const (
	// claudeTokenRefreshSkew: refresh the token this long before it expires.
	claudeTokenRefreshSkew = 3 * time.Minute
	// claudeTokenCacheSkew: the cache TTL is kept this far below the real expiry.
	claudeTokenCacheSkew = 5 * time.Minute
	// claudeLockWaitTime: how long to wait before re-reading the cache when
	// another worker holds the refresh lock.
	claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache is the token cache interface (reuses the GeminiTokenCache definition).
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages access tokens for Claude (Anthropic) OAuth accounts.
type ClaudeTokenProvider struct {
	accountRepo  AccountRepository // source of truth for account credentials
	tokenCache   ClaudeTokenCache  // shared token cache + distributed refresh lock
	oauthService *OAuthService     // performs the actual OAuth token refresh
}
// NewClaudeTokenProvider constructs a provider from its dependencies.
func NewClaudeTokenProvider(
	accountRepo AccountRepository,
	tokenCache ClaudeTokenCache,
	oauthService *OAuthService,
) *ClaudeTokenProvider {
	return &ClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   tokenCache,
		oauthService: oauthService,
	}
}
// GetAccessToken 获取有效的 access_token
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// 构建新 credentials保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
// 构建新 credentials保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}

View File

@@ -0,0 +1,939 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing.
// Call counters use atomic access so assertions are race-free.
type claudeTokenCacheStub struct {
	mu sync.Mutex
	tokens map[string]string // cacheKey -> token
	getErr error // forced error for GetAccessToken
	setErr error // forced error for SetAccessToken
	deleteErr error // forced error for DeleteAccessToken
	lockAcquired bool // AcquireRefreshLock result when no error/race is forced
	lockErr error // forced error for AcquireRefreshLock
	releaseLockErr error // forced error for ReleaseRefreshLock
	getCalled int32 // atomic call counters
	setCalled int32
	lockCalled int32
	unlockCalled int32
	simulateLockRace bool // when true, the lock appears held by another worker
}
// newClaudeTokenCacheStub returns a stub whose refresh lock is acquirable by default.
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
	return &claudeTokenCacheStub{
		tokens: make(map[string]string),
		lockAcquired: true,
	}
}
// GetAccessToken returns the stored token (empty string on miss) or the forced error.
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}
// SetAccessToken stores the token (ttl is ignored) or returns the forced error.
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}
// DeleteAccessToken removes the token or returns the forced error.
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}
// AcquireRefreshLock honors, in order: forced error, simulated race (lock held
// elsewhere), then the configured lockAcquired result.
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}
// ReleaseRefreshLock counts the call and returns the forced error (if any).
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider.
type claudeAccountRepoStub struct {
	account *Account // returned by GetByID; replaced by Update
	getErr error // forced error for GetByID
	updateErr error // forced error for Update
	getCalled int32 // atomic call counters
	updateCalled int32
}
// GetByID returns the stored account or the forced error.
// NOTE(review): the same *Account pointer is returned (no clone), so callers
// share mutations with the stub.
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}
// Update stores the account or returns the forced error.
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
// claudeOAuthServiceStub implements the OAuthService refresh method for testing.
type claudeOAuthServiceStub struct {
	tokenInfo *TokenInfo // returned on successful refresh
	refreshErr error // forced refresh error
	refreshCalled int32 // atomic call counter
}
// RefreshAccountToken returns the preset token info or the forced error.
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service.
// It deliberately mirrors ClaudeTokenProvider.GetAccessToken's control flow
// (cache check -> lock -> refresh -> fallback) so refresh-path behavior can be
// exercised without the concrete *OAuthService dependency.
// NOTE(review): keep this in sync with the production implementation.
type testClaudeTokenProvider struct {
	accountRepo *claudeAccountRepoStub
	tokenCache *claudeTokenCacheStub
	oauthService *claudeOAuthServiceStub
}
// GetAccessToken mirrors ClaudeTokenProvider.GetAccessToken using the stubs.
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
		return "", errors.New("not an anthropic oauth account")
	}
	cacheKey := ClaudeTokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh; mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed; mark to use a short TTL
					} else {
						// Build new credentials
						newCredentials := make(map[string]any)
						for k, v := range account.Credentials {
							newCredentials[k] = v
						}
						newCredentials["access_token"] = tokenInfo.AccessToken
						newCredentials["token_type"] = tokenInfo.TokenType
						newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
						if tokenInfo.RefreshToken != "" {
							newCredentials["refresh_token"] = tokenInfo.RefreshToken
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetCredential("access_token")
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // use a short TTL when refresh failed
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > claudeTokenCacheSkew {
				ttl = until - claudeTokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
cache := newClaudeTokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := ClaudeTokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(time.Hour).Unix(),
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.simulateLockRace = true
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 106,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
// TestClaudeTokenProvider_AccountRepoGetError verifies that a refresh still
// succeeds using the account passed in by the caller when re-reading the
// account from the repository fails.
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		getErr: errors.New("db connection failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       113,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Should still work, just using the passed-in account
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}
// TestClaudeTokenProvider_AccountUpdateError verifies that the freshly
// refreshed token is still returned to the caller even when persisting the
// updated account fails.
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{
		updateErr: errors.New("db write failed"),
	}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       114,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Should still return token even if update fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
}
// TestClaudeTokenProvider_RefreshPreservesExistingCredentials verifies that a
// token refresh overwrites only the token-related credential fields and keeps
// unrelated fields (custom_field, organization) intact.
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "new-access-token",
			RefreshToken: "new-refresh-token",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       115,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-access-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
			"custom_field":  "should-be-preserved",
			"organization":  "test-org",
		},
	}
	accountRepo.account = account
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "new-access-token", token)
	// Verify existing fields are preserved
	require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
	require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
	// Verify new fields are updated
	require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
	require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
}
// TestClaudeTokenProvider_DoubleCheckCacheAfterLock simulates another worker
// writing the cache shortly after this call starts; the provider should
// return some non-empty token regardless of which worker wins the race.
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	accountRepo := &claudeAccountRepoStub{}
	oauthService := &claudeOAuthServiceStub{
		tokenInfo: &TokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			TokenType:    "Bearer",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       116,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := ClaudeTokenCacheKey(account)
	// After lock is acquired, cache should have the token (simulating another worker)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other-worker"
		cache.mu.Unlock()
	}()
	provider := &testClaudeTokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
// TestClaudeTokenProvider_Real_LockFailedWait exercises the real provider's
// path where the refresh lock cannot be acquired: the provider waits while a
// simulated other worker populates the cache, and must return a token.
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       300,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := ClaudeTokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestClaudeTokenProvider_Real_CacheHitAfterWait is the faster variant of the
// lock-failed scenario: another worker fills the cache ~50ms in, and the
// provider must come back with a non-empty token.
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       301,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := ClaudeTokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestClaudeTokenProvider_Real_NoExpiresAt verifies that an account whose
// credentials carry no expires_at still yields its stored access token.
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic
	// Token with nil expires_at (no expiry set)
	account := &Account{
		ID:       302,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}
	// After lock wait, return token from credentials
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}
// TestClaudeTokenProvider_Real_WhitespaceToken verifies that a cached token
// consisting only of whitespace is treated as a cache miss, so the provider
// falls back to the valid access_token stored in the account credentials.
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       303,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}
	// Derive the key via ClaudeTokenCacheKey for consistency with the other
	// tests, instead of hard-coding "claude:account:303" (which would break
	// silently if the key format ever changed).
	cacheKey := ClaudeTokenCacheKey(account)
	cache.tokens[cacheKey] = "   " // Whitespace only - should be treated as empty
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token)
}
// TestClaudeTokenProvider_Real_EmptyCredentialToken verifies that a
// whitespace-only access_token in the credentials is rejected with an
// "access_token not found" error.
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       304,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "   ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestClaudeTokenProvider_Real_LockError verifies that an error while
// acquiring the refresh lock degrades gracefully to the credential token.
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       305,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}
// TestClaudeTokenProvider_Real_NilCredentials verifies that an account whose
// credentials lack an access_token entirely is rejected with an error.
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newClaudeTokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       306,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}
	provider := NewClaudeTokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

View File

@@ -0,0 +1,258 @@
package service
import (
"context"
"errors"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
	// defaultDashboardAggregationTimeout bounds a single scheduled aggregation pass.
	defaultDashboardAggregationTimeout = 2 * time.Minute
	// defaultDashboardAggregationBackfillTimeout bounds a full backfill run.
	defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
	// dashboardAggregationRetentionInterval is the minimum gap between retention cleanups.
	dashboardAggregationRetentionInterval = 6 * time.Hour
)

var (
	// ErrDashboardBackfillDisabled is returned when configuration disables backfill.
	ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
	// ErrDashboardBackfillTooLarge is returned when the backfill span exceeds the configured limit.
	ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
)
// DashboardAggregationRepository defines the persistence interface for
// dashboard pre-aggregation.
type DashboardAggregationRepository interface {
	// AggregateRange aggregates usage data over the [start, end) window.
	AggregateRange(ctx context.Context, start, end time.Time) error
	// GetAggregationWatermark returns the point in time data has been aggregated up to.
	GetAggregationWatermark(ctx context.Context) (time.Time, error)
	// UpdateAggregationWatermark advances the aggregation watermark.
	UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
	// CleanupAggregates removes hourly/daily aggregates older than the given cutoffs.
	CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
	// CleanupUsageLogs removes raw usage logs older than cutoff.
	CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
	// EnsureUsageLogsPartitions ensures storage partitions exist around now.
	EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
// DashboardAggregationService drives scheduled aggregation and on-demand
// backfill of dashboard statistics.
type DashboardAggregationService struct {
	repo        DashboardAggregationRepository
	timingWheel *TimingWheelService
	cfg         config.DashboardAggregationConfig
	// running serializes aggregation and backfill runs (CAS 0 -> 1).
	running int32
	// lastRetentionCleanup holds the time.Time of the last successful retention cleanup.
	lastRetentionCleanup atomic.Value // time.Time
}
// NewDashboardAggregationService builds the aggregation service. A nil cfg
// leaves the aggregation config at its zero value (everything disabled).
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
	svc := &DashboardAggregationService{
		repo:        repo,
		timingWheel: timingWheel,
	}
	if cfg != nil {
		svc.cfg = cfg.DashboardAgg
	}
	return svc
}
// Start registers the recurring aggregation job on the timing wheel.
// Configuration changes take effect only after a restart. No-ops when the
// service is not fully wired or aggregation is disabled.
func (s *DashboardAggregationService) Start() {
	if s == nil || s.repo == nil || s.timingWheel == nil {
		return
	}
	if !s.cfg.Enabled {
		log.Printf("[DashboardAggregation] 聚合作业已禁用")
		return
	}
	interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
	if interval <= 0 {
		// Guard against missing/invalid config; fall back to one minute.
		interval = time.Minute
	}
	if s.cfg.RecomputeDays > 0 {
		// Recompute recent days in the background so Start does not block.
		go s.recomputeRecentDays()
	}
	s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
		s.runScheduledAggregation()
	})
	log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
	if !s.cfg.BackfillEnabled {
		log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
	}
}
// TriggerBackfill validates the requested range and launches the backfill
// asynchronously. It returns ErrDashboardBackfillDisabled when backfill is
// turned off, ErrDashboardBackfillTooLarge when the span exceeds the
// configured maximum, and a plain error for an uninitialized service or an
// invalid range. Backfill failures are only logged (the work is async).
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
	switch {
	case s == nil || s.repo == nil:
		return errors.New("聚合服务未初始化")
	case !s.cfg.BackfillEnabled:
		log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
		return ErrDashboardBackfillDisabled
	case !end.After(start):
		return errors.New("回填时间范围无效")
	}
	if maxDays := s.cfg.BackfillMaxDays; maxDays > 0 {
		if end.Sub(start) > time.Duration(maxDays)*24*time.Hour {
			return ErrDashboardBackfillTooLarge
		}
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
		defer cancel()
		if err := s.backfillRange(ctx, start, end); err != nil {
			log.Printf("[DashboardAggregation] 回填失败: %v", err)
		}
	}()
	return nil
}
// recomputeRecentDays re-aggregates the most recent cfg.RecomputeDays days.
// It is invoked once at startup so restarts pick up data written while the
// service was down.
func (s *DashboardAggregationService) recomputeRecentDays() {
	if s.cfg.RecomputeDays <= 0 {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
	defer cancel()
	end := time.Now().UTC()
	begin := end.AddDate(0, 0, -s.cfg.RecomputeDays)
	if err := s.backfillRange(ctx, begin, end); err != nil {
		log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
	}
}
// runScheduledAggregation performs one scheduled aggregation pass: read the
// watermark, aggregate from (watermark - lookback) to now, advance the
// watermark, and opportunistically trigger retention cleanup. Overlapping
// runs are skipped via the running CAS flag.
func (s *DashboardAggregationService) runScheduledAggregation() {
	if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
		return
	}
	defer atomic.StoreInt32(&s.running, 0)
	jobStart := time.Now().UTC()
	ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
	defer cancel()
	now := time.Now().UTC()
	last, err := s.repo.GetAggregationWatermark(ctx)
	if err != nil {
		// Fall back to the epoch: treated below as "never aggregated".
		log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
		last = time.Unix(0, 0).UTC()
	}
	lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
	epoch := time.Unix(0, 0).UTC()
	// Re-aggregate a lookback window behind the watermark to absorb late writes.
	start := last.Add(-lookback)
	if !last.After(epoch) {
		// First run (or watermark read failure): start at the beginning of
		// the usage-log retention window instead of scanning all history.
		retentionDays := s.cfg.Retention.UsageLogsDays
		if retentionDays <= 0 {
			retentionDays = 1
		}
		start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
	} else if start.After(now) {
		// Watermark is ahead of the wall clock (e.g. skew); clamp the window.
		start = now.Add(-lookback)
	}
	if err := s.aggregateRange(ctx, start, now); err != nil {
		log.Printf("[DashboardAggregation] 聚合失败: %v", err)
		return
	}
	updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
	if updateErr != nil {
		// Watermark failure is non-fatal: the next pass re-aggregates the window.
		log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
	}
	log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
		start.Format(time.RFC3339),
		now.Format(time.RFC3339),
		time.Since(jobStart).String(),
		updateErr == nil,
	)
	s.maybeCleanupRetention(ctx, now)
}
// backfillRange re-aggregates [start, end) in day-sized windows and then
// advances the watermark to end. It is mutually exclusive with the scheduled
// job via the shared running flag. Inputs are normalized to UTC.
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
	if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
		return errors.New("聚合作业正在运行")
	}
	defer atomic.StoreInt32(&s.running, 0)
	jobStart := time.Now().UTC()
	startUTC := start.UTC()
	endUTC := end.UTC()
	if !endUTC.After(startUTC) {
		return errors.New("回填时间范围无效")
	}
	// Walk day by day from the start of the first day; the last window may be partial.
	cursor := truncateToDayUTC(startUTC)
	for cursor.Before(endUTC) {
		windowEnd := cursor.Add(24 * time.Hour)
		if windowEnd.After(endUTC) {
			windowEnd = endUTC
		}
		if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
			return err
		}
		cursor = windowEnd
	}
	updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
	if updateErr != nil {
		log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
	}
	log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
		startUTC.Format(time.RFC3339),
		endUTC.Format(time.RFC3339),
		time.Since(jobStart).String(),
		updateErr == nil,
	)
	s.maybeCleanupRetention(ctx, endUTC)
	return nil
}
// aggregateRange runs one aggregation pass over [start, end). Partition
// maintenance is best-effort: its failure is logged but does not abort the
// aggregation itself.
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
	if !start.Before(end) {
		// Empty or inverted window: nothing to do.
		return nil
	}
	if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
		log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
	}
	return s.repo.AggregateRange(ctx, start, end)
}
// maybeCleanupRetention prunes aggregates and usage logs according to the
// retention config, at most once per dashboardAggregationRetentionInterval.
// Only a fully successful run is recorded, so failures are retried on the
// next aggregation pass.
//
// Non-positive retention day values are clamped to 1. Without the clamp an
// unset (zero) value would yield a cutoff equal to now and delete all data;
// runScheduledAggregation already applies the same defensive default to
// UsageLogsDays.
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
	lastAny := s.lastRetentionCleanup.Load()
	if lastAny != nil {
		if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
			return
		}
	}
	hourlyCutoff := now.AddDate(0, 0, -clampRetentionDays(s.cfg.Retention.HourlyDays))
	dailyCutoff := now.AddDate(0, 0, -clampRetentionDays(s.cfg.Retention.DailyDays))
	usageCutoff := now.AddDate(0, 0, -clampRetentionDays(s.cfg.Retention.UsageLogsDays))
	aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
	if aggErr != nil {
		log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
	}
	usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
	if usageErr != nil {
		log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
	}
	if aggErr == nil && usageErr == nil {
		s.lastRetentionCleanup.Store(now)
	}
}

// clampRetentionDays guards against unset/invalid retention config: a value
// below 1 would make the cleanup cutoff equal to (or later than) now.
func clampRetentionDays(days int) int {
	if days <= 0 {
		return 1
	}
	return days
}
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,106 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// dashboardAggregationRepoTestStub is a configurable in-memory
// DashboardAggregationRepository used by the aggregation service tests.
type dashboardAggregationRepoTestStub struct {
	// aggregateCalls counts AggregateRange invocations; lastStart/lastEnd
	// record the most recent window.
	aggregateCalls int
	lastStart      time.Time
	lastEnd        time.Time
	// watermark is returned verbatim by GetAggregationWatermark.
	watermark time.Time
	// Optional injected errors for the corresponding methods.
	aggregateErr         error
	cleanupAggregatesErr error
	cleanupUsageErr      error
}
// AggregateRange records the call and requested window, then returns the
// configured aggregateErr.
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	s.aggregateCalls++
	s.lastStart = start
	s.lastEnd = end
	return s.aggregateErr
}

// GetAggregationWatermark returns the fixed watermark.
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	return s.watermark, nil
}

// UpdateAggregationWatermark is a no-op.
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

// CleanupAggregates returns the configured cleanupAggregatesErr.
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return s.cleanupAggregatesErr
}

// CleanupUsageLogs returns the configured cleanupUsageErr.
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return s.cleanupUsageErr
}

// EnsureUsageLogsPartitions is a no-op.
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
// TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart
// verifies that with an epoch watermark (never aggregated) the pass starts at
// the beginning of the usage-log retention window, not at the epoch.
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}
	svc.runScheduledAggregation()
	require.Equal(t, 1, repo.aggregateCalls)
	require.False(t, repo.lastEnd.IsZero())
	// Start must be the UTC midnight one retention day before the pass's end.
	require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
}
// TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord
// verifies that a failed retention cleanup does not record a success
// timestamp, so the cleanup is retried on the next pass.
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}
	svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
	require.Nil(t, svc.lastRetentionCleanup.Load())
}
// TestDashboardAggregationService_TriggerBackfill_TooLarge verifies that a
// backfill span exceeding BackfillMaxDays is rejected without any
// aggregation work being started.
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			BackfillEnabled: true,
			BackfillMaxDays: 1,
		},
	}
	start := time.Now().AddDate(0, 0, -3)
	end := time.Now()
	err := svc.TriggerBackfill(start, end)
	require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
	require.Equal(t, 0, repo.aggregateCalls)
}

View File

@@ -2,47 +2,307 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
// DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct {
usageRepo UsageLogRepository
const (
defaultDashboardStatsFreshTTL = 15 * time.Second
defaultDashboardStatsCacheTTL = 30 * time.Second
defaultDashboardStatsRefreshTimeout = 30 * time.Second
)
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
// DashboardStatsCache 定义仪表盘统计缓存接口。
type DashboardStatsCache interface {
GetDashboardStats(ctx context.Context) (string, error)
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
DeleteDashboardStats(ctx context.Context) error
}
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
type dashboardStatsRangeFetcher interface {
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
}
type dashboardStatsCacheEntry struct {
Stats *usagestats.DashboardStats `json:"stats"`
UpdatedAt int64 `json:"updated_at"`
}
// DashboardService 提供管理员仪表盘统计服务。
type DashboardService struct {
usageRepo UsageLogRepository
aggRepo DashboardAggregationRepository
cache DashboardStatsCache
cacheFreshTTL time.Duration
cacheTTL time.Duration
refreshTimeout time.Duration
refreshing int32
aggEnabled bool
aggInterval time.Duration
aggLookback time.Duration
aggUsageDays int
}
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
freshTTL := defaultDashboardStatsFreshTTL
cacheTTL := defaultDashboardStatsCacheTTL
refreshTimeout := defaultDashboardStatsRefreshTimeout
aggEnabled := true
aggInterval := time.Minute
aggLookback := 2 * time.Minute
aggUsageDays := 90
if cfg != nil {
if !cfg.Dashboard.Enabled {
cache = nil
}
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsTTLSeconds > 0 {
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
}
aggEnabled = cfg.DashboardAgg.Enabled
if cfg.DashboardAgg.IntervalSeconds > 0 {
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
}
if cfg.DashboardAgg.LookbackSeconds > 0 {
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
}
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
}
}
if aggRepo == nil {
aggEnabled = false
}
return &DashboardService{
usageRepo: usageRepo,
usageRepo: usageRepo,
aggRepo: aggRepo,
cache: cache,
cacheFreshTTL: freshTTL,
cacheTTL: cacheTTL,
refreshTimeout: refreshTimeout,
aggEnabled: aggEnabled,
aggInterval: aggInterval,
aggLookback: aggLookback,
aggUsageDays: aggUsageDays,
}
}
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.usageRepo.GetDashboardStats(ctx)
if s.cache != nil {
cached, fresh, err := s.getCachedDashboardStats(ctx)
if err == nil && cached != nil {
s.refreshAggregationStaleness(cached)
if !fresh {
s.refreshDashboardStatsAsync()
}
return cached, nil
}
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
}
}
stats, err := s.refreshDashboardStats(ctx)
if err != nil {
return nil, fmt.Errorf("get dashboard stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
// getCachedDashboardStats reads the cached stats entry and returns the stats
// plus a freshness flag (entry age <= cacheFreshTTL). Corrupt or incomplete
// payloads are evicted from the cache and reported as a cache miss.
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
	data, err := s.cache.GetDashboardStats(ctx)
	if err != nil {
		return nil, false, err
	}
	var entry dashboardStatsCacheEntry
	if err := json.Unmarshal([]byte(data), &entry); err != nil {
		// Unparseable payload: drop it so the next reader does not trip on it.
		s.evictDashboardStatsCache(err)
		return nil, false, ErrDashboardStatsCacheMiss
	}
	if entry.Stats == nil {
		s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
		return nil, false, ErrDashboardStatsCacheMiss
	}
	age := time.Since(time.Unix(entry.UpdatedAt, 0))
	return entry.Stats, age <= s.cacheFreshTTL, nil
}
// refreshDashboardStats loads fresh stats, stamps the aggregation status on
// them, and writes the result through to the cache (best-effort, on its own
// short-lived context).
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	stats, fetchErr := s.fetchDashboardStats(ctx)
	if fetchErr != nil {
		return nil, fetchErr
	}
	s.applyAggregationStatus(ctx, stats)
	func() {
		cacheCtx, cancel := s.cacheOperationContext()
		defer cancel()
		s.saveDashboardStatsCache(cacheCtx, stats)
	}()
	return stats, nil
}
// refreshDashboardStatsAsync refreshes the cached stats in a background
// goroutine. The refreshing CAS flag guarantees at most one concurrent
// refresh; additional callers return immediately.
func (s *DashboardService) refreshDashboardStatsAsync() {
	if s.cache == nil {
		return
	}
	if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
		return
	}
	go func() {
		defer atomic.StoreInt32(&s.refreshing, 0)
		ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
		defer cancel()
		stats, err := s.fetchDashboardStats(ctx)
		if err != nil {
			log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
			return
		}
		s.applyAggregationStatus(ctx, stats)
		// Separate context for the cache write; both deferred cancels run on exit.
		cacheCtx, cancel := s.cacheOperationContext()
		defer cancel()
		s.saveDashboardStatsCache(cacheCtx, stats)
	}()
}
// fetchDashboardStats loads stats from the usage repository. When
// pre-aggregation is disabled and the repository supports range queries, the
// scan is limited to the usage-log retention window (aggUsageDays) to bound
// query cost; otherwise the full GetDashboardStats path is used.
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	if !s.aggEnabled {
		if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
			now := time.Now().UTC()
			start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
			return fetcher.GetDashboardStatsWithRange(ctx, start, now)
		}
	}
	return s.usageRepo.GetDashboardStats(ctx)
}
// saveDashboardStatsCache serializes stats together with the current
// timestamp and stores the entry under the configured TTL. Failures are
// logged and otherwise ignored — the cache is only an optimization.
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
	if s.cache == nil || stats == nil {
		return
	}
	payload, marshalErr := json.Marshal(dashboardStatsCacheEntry{
		Stats:     stats,
		UpdatedAt: time.Now().Unix(),
	})
	if marshalErr != nil {
		log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", marshalErr)
		return
	}
	if setErr := s.cache.SetDashboardStats(ctx, string(payload), s.cacheTTL); setErr != nil {
		log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", setErr)
	}
}
// evictDashboardStatsCache removes the cached stats entry, logging both the
// deletion failure (if any) and the reason the entry was considered bad.
func (s *DashboardService) evictDashboardStatsCache(reason error) {
	if s.cache == nil {
		return
	}
	ctx, cancel := s.cacheOperationContext()
	defer cancel()
	if delErr := s.cache.DeleteDashboardStats(ctx); delErr != nil {
		log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", delErr)
	}
	if reason == nil {
		return
	}
	log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
}
// cacheOperationContext returns a background-derived context bounded by the
// refresh timeout, so best-effort cache reads/writes are independent of the
// caller's (possibly already-cancelled) request context.
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.Background(), s.refreshTimeout)
}
// applyAggregationStatus stamps the aggregation watermark (RFC3339, UTC) and
// the staleness flag onto stats before they are returned or cached.
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
	if stats == nil {
		return
	}
	updatedAt := s.fetchAggregationUpdatedAt(ctx)
	stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
	stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
// refreshAggregationStaleness recomputes StatsStale for cached stats from the
// StatsUpdatedAt timestamp embedded in the entry — no repository round-trip.
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
	if stats == nil {
		return
	}
	updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
	stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
// fetchAggregationUpdatedAt returns the aggregation watermark in UTC, or the
// Unix epoch when no aggregation repository is wired, the read fails, or the
// stored watermark is the zero time.
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
	epoch := time.Unix(0, 0).UTC()
	if s.aggRepo == nil {
		return epoch
	}
	watermark, err := s.aggRepo.GetAggregationWatermark(ctx)
	switch {
	case err != nil:
		log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
		return epoch
	case watermark.IsZero():
		return epoch
	}
	return watermark.UTC()
}
// isAggregationStale reports whether the aggregated stats should be flagged
// stale: aggregation disabled, watermark never advanced past the epoch, or
// watermark older than one aggregation interval plus the lookback window.
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
	if !s.aggEnabled {
		return true
	}
	if !updatedAt.After(time.Unix(0, 0).UTC()) {
		return true
	}
	return now.Sub(updatedAt) > s.aggInterval+s.aggLookback
}
func parseStatsUpdatedAt(raw string) time.Time {
if raw == "" {
return time.Unix(0, 0).UTC()
}
parsed, err := time.Parse(time.RFC3339, raw)
if err != nil {
return time.Unix(0, 0).UTC()
}
return parsed.UTC()
}
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {

View File

@@ -0,0 +1,387 @@
package service
import (
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
// usageRepoStub embeds UsageLogRepository (unimplemented methods panic if
// called) and overrides the dashboard-stats fetch paths with configurable
// results, call counters, and an optional call-notification channel.
type usageRepoStub struct {
	UsageLogRepository
	// stats/err drive GetDashboardStats; rangeStats/rangeErr drive the ranged variant.
	stats      *usagestats.DashboardStats
	rangeStats *usagestats.DashboardStats
	err        error
	rangeErr   error
	// calls/rangeCalls are incremented atomically per invocation.
	calls      int32
	rangeCalls int32
	// rangeStart/rangeEnd capture the window of the latest ranged call.
	rangeStart time.Time
	rangeEnd   time.Time
	// onCall, when set, receives a non-blocking signal per GetDashboardStats call.
	onCall chan struct{}
}
// GetDashboardStats counts the call, signals onCall (non-blocking) if set,
// then returns the configured err or stats.
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.calls, 1)
	if s.onCall != nil {
		select {
		case s.onCall <- struct{}{}:
		default:
		}
	}
	if s.err != nil {
		return nil, s.err
	}
	return s.stats, nil
}

// GetDashboardStatsWithRange counts the call, records the requested window,
// and returns rangeErr, rangeStats, or (as a fallback) stats.
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.rangeCalls, 1)
	s.rangeStart = start
	s.rangeEnd = end
	if s.rangeErr != nil {
		return nil, s.rangeErr
	}
	if s.rangeStats != nil {
		return s.rangeStats, nil
	}
	return s.stats, nil
}
// dashboardCacheStub is a configurable DashboardStatsCache: each operation
// can be overridden by a function hook, invocation counts are tracked
// atomically, and the last Set payload is retained for inspection.
type dashboardCacheStub struct {
	// Optional per-operation hooks; defaults are a cache miss / no-op.
	get func(ctx context.Context) (string, error)
	set func(ctx context.Context, data string, ttl time.Duration) error
	del func(ctx context.Context) error
	// Atomic invocation counters.
	getCalls int32
	setCalls int32
	delCalls int32
	// lastSet stores the most recent Set payload, guarded by lastSetMu.
	lastSetMu sync.Mutex
	lastSet   string
}
// GetDashboardStats counts the call and delegates to the get hook, defaulting
// to a cache miss.
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
	atomic.AddInt32(&c.getCalls, 1)
	if c.get != nil {
		return c.get(ctx)
	}
	return "", ErrDashboardStatsCacheMiss
}

// SetDashboardStats counts the call, records the payload, and delegates to
// the set hook (default: success).
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
	atomic.AddInt32(&c.setCalls, 1)
	c.lastSetMu.Lock()
	c.lastSet = data
	c.lastSetMu.Unlock()
	if c.set != nil {
		return c.set(ctx, data, ttl)
	}
	return nil
}

// DeleteDashboardStats counts the call and delegates to the del hook
// (default: success).
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
	atomic.AddInt32(&c.delCalls, 1)
	if c.del != nil {
		return c.del(ctx)
	}
	return nil
}
// dashboardAggregationRepoStub is a minimal DashboardAggregationRepository
// for the dashboard-service tests: only the watermark read is configurable,
// everything else is a no-op.
type dashboardAggregationRepoStub struct {
	// watermark/err drive GetAggregationWatermark.
	watermark time.Time
	err       error
}

// AggregateRange is a no-op.
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	return nil
}

// GetAggregationWatermark returns the configured err or watermark.
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	if s.err != nil {
		return time.Time{}, s.err
	}
	return s.watermark, nil
}

// UpdateAggregationWatermark is a no-op.
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

// CleanupAggregates is a no-op.
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return nil
}

// CleanupUsageLogs is a no-op.
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return nil
}

// EnsureUsageLogsPartitions is a no-op.
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}
// readLastEntry decodes the most recently Set cache payload into a
// dashboardStatsCacheEntry, failing the test on unmarshal errors.
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
	t.Helper()
	c.lastSetMu.Lock()
	data := c.lastSet
	c.lastSetMu.Unlock()
	var entry dashboardStatsCacheEntry
	err := json.Unmarshal([]byte(data), &entry)
	require.NoError(t, err)
	return entry
}
// TestDashboardService_CacheHitFresh verifies that a fresh cache entry is
// served as-is: the repository is never queried and nothing is rewritten to
// the cache.
func TestDashboardService_CacheHitFresh(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     10,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	entry := dashboardStatsCacheEntry{
		Stats:     stats,
		UpdatedAt: time.Now().Unix(),
	}
	payload, err := json.Marshal(entry)
	require.NoError(t, err)
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return string(payload), nil
		},
	}
	// Repo returns different numbers so a repo hit would be detectable.
	repo := &usageRepoStub{
		stats: &usagestats.DashboardStats{TotalUsers: 99},
	}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
// TestDashboardService_CacheMiss_StoresCache verifies that a cache miss falls
// through to the repository and the freshly computed stats are written back
// with a current UpdatedAt timestamp.
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
	want := &usagestats.DashboardStats{
		TotalUsers:     7,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "", ErrDashboardStatsCacheMiss
		},
	}
	repo := &usageRepoStub{stats: want}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard:    config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
	}

	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, want, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))

	entry := cache.readLastEntry(t)
	require.Equal(t, want, entry.Stats)
	require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
}
// TestDashboardService_CacheDisabled_SkipsCache verifies that when the
// dashboard cache is disabled the service goes straight to the repository
// and never touches the cache for reads or writes.
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
	want := &usagestats.DashboardStats{
		TotalUsers:     3,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	// The get stub would succeed if called; the counters prove it is not.
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) { return "", nil },
	}
	repo := &usageRepoStub{stats: want}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard:    config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
	}

	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, want, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
// TestDashboardService_CacheHitStale_TriggersAsyncRefresh verifies that a
// stale cache entry is returned immediately while a background refresh is
// kicked off and eventually rewrites the cache.
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
	stale := &usagestats.DashboardStats{
		TotalUsers:     11,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	// UpdatedAt is pushed well beyond the fresh TTL so the entry reads as stale.
	raw, err := json.Marshal(dashboardStatsCacheEntry{
		Stats:     stale,
		UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
	})
	require.NoError(t, err)
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) { return string(raw), nil },
	}
	refreshed := make(chan struct{}, 1)
	repo := &usageRepoStub{
		stats:  &usagestats.DashboardStats{TotalUsers: 22},
		onCall: refreshed,
	}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard:    config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
	}

	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stale, got)

	// The stale value must be returned synchronously; the refresh runs async.
	select {
	case <-refreshed:
	case <-time.After(1 * time.Second):
		t.Fatal("等待异步刷新超时")
	}
	require.Eventually(t, func() bool {
		return atomic.LoadInt32(&cache.setCalls) >= 1
	}, 1*time.Second, 10*time.Millisecond)
}
// TestDashboardService_CacheParseError_EvictsAndRefetches verifies that an
// unparsable cache payload is evicted and the stats are refetched from the
// repository.
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) { return "not-json", nil },
	}
	want := &usagestats.DashboardStats{TotalUsers: 9}
	repo := &usageRepoStub{stats: want}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard:    config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
	}

	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, want, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
}
// TestDashboardService_CacheParseError_RepoFailure verifies that a corrupt
// cache entry is still evicted even when the subsequent repository fetch
// fails, and that the failure is surfaced to the caller.
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) { return "not-json", nil },
	}
	repo := &usageRepoStub{err: errors.New("db down")}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard:    config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
	}

	svc := NewDashboardService(repo, aggRepo, cache, cfg)
	_, err := svc.GetDashboardStats(context.Background())
	require.Error(t, err)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
}
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
require.True(t, got.StatsStale)
}
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
aggNow := time.Now().UTC().Truncate(time.Second)
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
},
}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
require.False(t, got.StatsStale)
}
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
expected := &usagestats.DashboardStats{TotalUsers: 42}
repo := &usageRepoStub{
rangeStats: expected,
err: errors.New("should not call aggregated stats"),
}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: false,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 7,
},
},
}
svc := NewDashboardService(repo, nil, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, int64(42), got.TotalUsers)
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
require.False(t, repo.rangeEnd.IsZero())
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
}

View File

@@ -38,6 +38,12 @@ const (
RedeemTypeSubscription = "subscription"
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
@@ -57,6 +63,9 @@ const (
SubscriptionStatusSuspended = "suspended"
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// Setting keys
const (
// 注册设置
@@ -77,6 +86,12 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
@@ -84,6 +99,7 @@ const (
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML或 URL 作为 iframe src
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
@@ -106,16 +122,38 @@ const (
SettingKeyEnableIdentityPatch = "enable_identity_patch"
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
// LinuxDo Connect OAuth 登录(终端用户 SSO
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
)
// =========================
// Ops Monitoring (vNext)
// =========================
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀RFC 保留域名)。
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// SettingKeyOpsMonitoringEnabled is a DB-backed soft switch to enable/disable ops module at runtime.
SettingKeyOpsMonitoringEnabled = "ops_monitoring_enabled"
// SettingKeyOpsRealtimeMonitoringEnabled controls realtime features (e.g. WS/QPS push).
SettingKeyOpsRealtimeMonitoringEnabled = "ops_realtime_monitoring_enabled"
// SettingKeyOpsQueryModeDefault controls the default query mode for ops dashboard (auto/raw/preagg).
SettingKeyOpsQueryModeDefault = "ops_query_mode_default"
// SettingKeyOpsEmailNotificationConfig stores JSON config for ops email notifications.
SettingKeyOpsEmailNotificationConfig = "ops_email_notification_config"
// SettingKeyOpsAlertRuntimeSettings stores JSON config for ops alert evaluator runtime settings.
SettingKeyOpsAlertRuntimeSettings = "ops_alert_runtime_settings"
// SettingKeyOpsMetricsIntervalSeconds controls the ops metrics collector interval (>=60).
SettingKeyOpsMetricsIntervalSeconds = "ops_metrics_interval_seconds"
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
// =========================
// Stream Timeout Handling
// =========================
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-"

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
accounts []Account
accountsByID map[int64]*Account
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
getByIDCalls int
}
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
m.getByIDCalls++
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
@@ -142,6 +145,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -157,6 +163,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -194,6 +203,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func ptr[T any](v T) *T {
return &v
}
@@ -900,6 +959,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
return m.accountWaitCounts[accountID], nil
}
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
}
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
return true, nil
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
result := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
@@ -928,13 +1055,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
groupID := int64(1)
sessionHash := "sticky"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-a": {1},
"claude-b": {2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // legacy path
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -959,7 +1140,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -991,13 +1172,85 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
})
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
@@ -1016,9 +1269,264 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
ctxGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
groupID := int64(42)
invalidGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
hydratedGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
svc := &GatewayService{}
ctx = svc.withGroupContext(ctx, hydratedGroup)
got, ok := ctx.Value(ctxkey.Group).(*Group)
require.True(t, ok)
require.Same(t, hydratedGroup, got)
}
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{fallbackID: fallbackGroup},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &groupID,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: group,
fallbackID: fallbackGroup,
},
}
svc := &GatewayService{
groupRepo: groupRepo,
}
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
require.Error(t, err)
require.Nil(t, gotGroup)
require.Nil(t, gotID)
require.Contains(t, err.Error(), "fallback group cycle")
}

File diff suppressed because it is too large Load Diff

View File

@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
accountRepo AccountRepository
groupRepo GroupRepository
cache GatewayCache
schedulerSnapshot *SchedulerSnapshotService
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
accountRepo AccountRepository,
groupRepo GroupRepository,
cache GatewayCache,
schedulerSnapshot *SchedulerSnapshotService,
tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
accountRepo: accountRepo,
groupRepo: groupRepo,
cache: cache,
schedulerSnapshot: schedulerSnapshot,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
@@ -86,9 +89,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {
@@ -99,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
var queryPlatforms []string
if useMixedScheduling {
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
} else {
queryPlatforms = []string{platform}
}
cacheKey := "gemini:" + sessionHash
@@ -112,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
@@ -143,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
var selected *Account
@@ -239,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
return s.antigravityGatewayService
}
// getSchedulableAccount resolves a single schedulable account by ID,
// preferring the in-memory scheduler snapshot and falling back to the
// account repository when no snapshot is configured.
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
	snapshot := s.schedulerSnapshot
	if snapshot == nil {
		return s.accountRepo.GetByID(ctx, accountID)
	}
	return snapshot.GetAccount(ctx, accountID)
}
// listSchedulableAccountsOnce lists schedulable accounts for the given group
// and platform in a single query, using the scheduler snapshot when available.
// For the gemini platform without a forced platform, mixed scheduling also
// includes antigravity accounts in the query.
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
	if snap := s.schedulerSnapshot; snap != nil {
		accounts, _, err := snap.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
		return accounts, err
	}
	platforms := []string{platform}
	if platform == PlatformGemini && !hasForcePlatform {
		// Mixed scheduling: gemini traffic may also be served by antigravity accounts.
		platforms = append(platforms, PlatformAntigravity)
	}
	if groupID == nil {
		return s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
	}
	return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
}
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -260,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
}
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
if err != nil {
return false, err
}
@@ -282,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
@@ -535,14 +545,30 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
continue
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
}
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
@@ -552,6 +578,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
_ = resp.Body.Close()
if isGeminiSignatureRelatedError(respBody) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "signature_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
var strippedClaudeBody []byte
stageName := ""
switch signatureRetryStage {
@@ -602,6 +653,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
@@ -627,12 +703,64 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
}
requestID := resp.Header.Get(requestIDHeader)
@@ -855,8 +983,23 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
// Capture upstream request body for ops retry of this attempt.
if c != nil {
// In this code path `body` is already the JSON sent to upstream.
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
sleepGeminiBackoff(attempt)
@@ -874,7 +1017,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
setOpsUpstreamError(c, 0, safeErr, "")
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
}
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
@@ -893,6 +1037,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if attempt < geminiMaxRetries {
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
upstreamReqID = resp.Header.Get("x-goog-request-id")
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
Message: upstreamMsg,
Detail: upstreamDetail,
})
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
sleepGeminiBackoff(attempt)
continue
@@ -956,19 +1125,87 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
if tempMatched {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(evBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
respBody = unwrapIfNeeded(isOAuth, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, respBody)
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
if upstreamMsg == "" {
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("gemini upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
var usage *ClaudeUsage
@@ -1070,7 +1307,33 @@ func sanitizeUpstreamErrorMessage(msg string) string {
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
}
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
var statusCode int
var errType, errMsg string
@@ -1178,7 +1441,10 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, ups
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
return fmt.Errorf("upstream error: %d", upstreamStatus)
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
}
type claudeErrorMapping struct {

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -127,6 +128,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
// SetModelRateLimit is a no-op stub satisfying AccountRepository for tests.
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
	return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -140,6 +144,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
// ClearModelRateLimits is a no-op stub satisfying AccountRepository for tests.
func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
	return nil
}
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -155,10 +162,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type mockGroupRepoForGemini struct {
groups map[int64]*Group
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, errors.New("group not found")
}
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
@@ -251,6 +269,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
}
// TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup verifies
// that a hydrated group already carried in the request context is reused
// directly, with no group-repository lookup at all.
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
	ctx := context.Background()
	groupID := int64(7)
	// Hydrated=true marks the group as loaded from a trusted source, so
	// IsGroupContextValid accepts it and the service skips the repo fetch.
	group := &Group{
		ID:       groupID,
		Platform: PlatformGemini,
		Status:   StatusActive,
		Hydrated: true,
	}
	ctx = context.WithValue(ctx, ctxkey.Group, group)
	repo := &mockAccountRepoForGemini{
		accounts: []Account{
			{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
		},
		accountsByID: map[int64]*Account{},
	}
	for i := range repo.accounts {
		repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
	}
	cache := &mockGatewayCacheForGemini{}
	groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
	svc := &GeminiMessagesCompatService{
		accountRepo: repo,
		groupRepo:   groupRepo,
		cache:       cache,
	}
	acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
	require.NoError(t, err)
	require.NotNil(t, acc)
	// Neither fetch path should have been used: the context group wins.
	require.Equal(t, 0, groupRepo.getByIDCalls)
	require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
// TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch verifies that
// when no usable group is present in the context, the service resolves the
// group exactly once via the lightweight GetByIDLite path (never GetByID).
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
	ctx := context.Background()
	groupID := int64(7)
	repo := &mockAccountRepoForGemini{
		accounts: []Account{
			{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
		},
		accountsByID: map[int64]*Account{},
	}
	for i := range repo.accounts {
		repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
	}
	cache := &mockGatewayCacheForGemini{}
	groupRepo := &mockGroupRepoForGemini{
		groups: map[int64]*Group{
			groupID: {ID: groupID, Platform: PlatformGemini},
		},
	}
	svc := &GeminiMessagesCompatService{
		accountRepo: repo,
		groupRepo:   groupRepo,
		cache:       cache,
	}
	acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
	require.NoError(t, err)
	require.NotNil(t, acc)
	// The heavy GetByID path must be untouched; the lite path is hit once.
	require.Equal(t, 0, groupRepo.getByIDCalls)
	require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
ctx := context.Background()

View File

@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error

View File

@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account")
}
cacheKey := geminiTokenCacheKey(account)
cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func geminiTokenCacheKey(account *Account) string {
func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return projectID
return "gemini:" + projectID
}
return "account:" + strconv.FormatInt(account.ID, 10)
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
}

View File

@@ -1,6 +1,9 @@
package service
import "time"
import (
"strings"
"time"
)
type Group struct {
ID int64
@@ -10,6 +13,7 @@ type Group struct {
RateMultiplier float64
IsExclusive bool
Status string
Hydrated bool // indicates the group was loaded from a trusted repository source
SubscriptionType string
DailyLimitUSD *float64
@@ -26,6 +30,12 @@ type Group struct {
ClaudeCodeOnly bool
FallbackGroupID *int64
// 模型路由配置
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*"
// value: 优先账号 ID 列表
ModelRouting map[string][]int64
ModelRoutingEnabled bool
CreatedAt time.Time
UpdatedAt time.Time
@@ -72,3 +82,58 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return g.ImagePrice2K
}
}
// IsGroupContextValid reports whether a group taken from the request context
// carries every field required for routing decisions: a positive ID, the
// Hydrated marker (loaded from a trusted repository source), and non-empty
// Platform and Status.
func IsGroupContextValid(group *Group) bool {
	return group != nil &&
		group.ID > 0 &&
		group.Hydrated &&
		group.Platform != "" &&
		group.Status != ""
}
// GetRoutingAccountIDs returns the preferred account ID list routed for the
// requested model, or nil when routing is disabled or no rule matches.
//
// Exact pattern matches take precedence over wildcard ("prefix-*") patterns.
// When several wildcard patterns match, the longest (most specific) pattern
// wins, with lexicographic order as a tie-breaker, so the result is
// deterministic instead of depending on random map iteration order.
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
	if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
		return nil
	}
	// 1. Exact match has priority.
	if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
		return accountIDs
	}
	// 2. Wildcard (prefix) match: previously the first matching pattern in
	// map iteration order won, which is random in Go. Pick deterministically.
	var bestPattern string
	var bestIDs []int64
	for pattern, accountIDs := range g.ModelRouting {
		if len(accountIDs) == 0 || !matchModelPattern(pattern, requestedModel) {
			continue
		}
		if bestIDs == nil ||
			len(pattern) > len(bestPattern) ||
			(len(pattern) == len(bestPattern) && pattern < bestPattern) {
			bestPattern, bestIDs = pattern, accountIDs
		}
	}
	return bestIDs
}
// matchModelPattern reports whether model matches pattern.
// A trailing "*" turns the pattern into a prefix wildcard (e.g.
// "claude-opus-*" matches "claude-opus-4-20250514"); otherwise the
// comparison is an exact string match.
func matchModelPattern(pattern, model string) bool {
	if strings.HasSuffix(pattern, "*") {
		return strings.HasPrefix(model, strings.TrimSuffix(pattern, "*"))
	}
	return pattern == model
}

View File

@@ -16,6 +16,7 @@ var (
type GroupRepository interface {
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
GetByIDLite(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
@@ -49,13 +50,15 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo GroupRepository
groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
return &GroupService{
groupRepo: groupRepo,
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -154,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
@@ -166,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return fmt.Errorf("get group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}

View File

@@ -0,0 +1,56 @@
package service
import (
"strings"
"time"
)
// modelRateLimitsKey is the Account.Extra key under which per-model
// rate-limit state is stored.
const modelRateLimitsKey = "model_rate_limits"

// modelRateLimitScopeClaudeSonnet is the rate-limit scope covering every
// Claude "sonnet" model.
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"

// resolveModelRateLimitScope maps a requested model name to its rate-limit
// scope. The lookup is case-insensitive and tolerates a "models/" prefix.
// It reports false when the model does not belong to any known scope.
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
	model := strings.TrimSpace(strings.ToLower(requestedModel))
	model = strings.TrimPrefix(model, "models/")
	if strings.Contains(model, "sonnet") {
		return modelRateLimitScopeClaudeSonnet, true
	}
	return "", false
}
// isModelRateLimited reports whether the account currently has an active
// model-level rate limit covering the requested model. Models outside any
// known scope are never considered limited.
func (a *Account) isModelRateLimited(requestedModel string) bool {
	scope, ok := resolveModelRateLimitScope(requestedModel)
	if !ok {
		return false
	}
	if resetAt := a.modelRateLimitResetAt(scope); resetAt != nil {
		// Limited only while the stored reset time is still in the future.
		return time.Now().Before(*resetAt)
	}
	return false
}
// modelRateLimitResetAt returns the stored rate-limit reset time for scope,
// or nil when the account carries no such entry or the entry is malformed.
// Expected layout: Extra["model_rate_limits"][scope]["rate_limit_reset_at"]
// holding an RFC3339 timestamp string.
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
	if a == nil || scope == "" {
		return nil
	}
	if a.Extra == nil {
		return nil
	}
	// Failed type assertions yield nil maps / empty strings; indexing a nil
	// map is safe in Go, so each step simply falls through to the guards.
	limits, _ := a.Extra[modelRateLimitsKey].(map[string]any)
	limit, _ := limits[scope].(map[string]any)
	raw, _ := limit["rate_limit_reset_at"].(string)
	if strings.TrimSpace(raw) == "" {
		return nil
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil
	}
	return &parsed
}

View File

@@ -0,0 +1,528 @@
package service
import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
//go:embed prompts/codex_cli_instructions.md
var codexCLIInstructions string
var codexModelMap = map[string]string{
"gpt-5.1-codex": "gpt-5.1-codex",
"gpt-5.1-codex-low": "gpt-5.1-codex",
"gpt-5.1-codex-medium": "gpt-5.1-codex",
"gpt-5.1-codex-high": "gpt-5.1-codex",
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
"gpt-5.2-codex": "gpt-5.2-codex",
"gpt-5.2-codex-low": "gpt-5.2-codex",
"gpt-5.2-codex-medium": "gpt-5.2-codex",
"gpt-5.2-codex-high": "gpt-5.2-codex",
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5.1": "gpt-5.1",
"gpt-5.1-none": "gpt-5.1",
"gpt-5.1-low": "gpt-5.1",
"gpt-5.1-medium": "gpt-5.1",
"gpt-5.1-high": "gpt-5.1",
"gpt-5.1-chat-latest": "gpt-5.1",
"gpt-5-codex": "gpt-5.1-codex",
"codex-mini-latest": "gpt-5.1-codex-mini",
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5": "gpt-5.1",
"gpt-5-mini": "gpt-5.1",
"gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
Modified bool
NormalizedModel string
PromptCacheKey string
}
type opencodeCacheMetadata struct {
ETag string `json:"etag"`
LastFetch string `json:"lastFetch,omitempty"`
LastChecked int64 `json:"lastChecked"`
}
// applyCodexOAuthTransform rewrites an OpenAI Responses-style request body in
// place so it is acceptable to the ChatGPT internal (Codex OAuth) API, and
// reports what was normalized: whether the body changed, the normalized model
// name, and the trimmed prompt_cache_key (if any).
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
	result := codexTransformResult{}
	// Tool-continuation requests affect both the storage strategy and the
	// input filtering logic below.
	needsToolContinuation := NeedsToolContinuation(reqBody)
	model := ""
	if v, ok := reqBody["model"].(string); ok {
		model = v
	}
	normalizedModel := normalizeCodexModel(model)
	if normalizedModel != "" {
		if model != normalizedModel {
			reqBody["model"] = normalizedModel
			result.Modified = true
		}
		result.NormalizedModel = normalizedModel
	}
	// When OAuth goes through the ChatGPT internal API, "store" must be false;
	// an explicit true is forcibly overridden as well, to avoid the upstream
	// error "Store must be set to false".
	if v, ok := reqBody["store"].(bool); !ok || v {
		reqBody["store"] = false
		result.Modified = true
	}
	// Force streaming on (missing or false both become true).
	if v, ok := reqBody["stream"].(bool); !ok || !v {
		reqBody["stream"] = true
		result.Modified = true
	}
	// Token-limit fields are stripped before forwarding.
	if _, ok := reqBody["max_output_tokens"]; ok {
		delete(reqBody, "max_output_tokens")
		result.Modified = true
	}
	if _, ok := reqBody["max_completion_tokens"]; ok {
		delete(reqBody, "max_completion_tokens")
		result.Modified = true
	}
	if normalizeCodexTools(reqBody) {
		result.Modified = true
	}
	if v, ok := reqBody["prompt_cache_key"].(string); ok {
		result.PromptCacheKey = strings.TrimSpace(v)
	}
	instructions := strings.TrimSpace(getOpenCodeCodexHeader())
	existingInstructions, _ := reqBody["instructions"].(string)
	existingInstructions = strings.TrimSpace(existingInstructions)
	if instructions != "" {
		if existingInstructions != instructions {
			reqBody["instructions"] = instructions
			result.Modified = true
		}
	} else if existingInstructions == "" {
		// No opencode instructions available: fall back to the bundled
		// Codex CLI instructions.
		codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
		if codexInstructions != "" {
			reqBody["instructions"] = codexInstructions
			result.Modified = true
		}
	}
	// Continuation scenarios keep item_reference and id entries so the
	// call_id context is not lost.
	// NOTE(review): Modified is set whenever "input" exists, even if
	// filterCodexInput returned it unchanged — confirm this is intentional.
	if input, ok := reqBody["input"].([]any); ok {
		input = filterCodexInput(input, needsToolContinuation)
		reqBody["input"] = input
		result.Modified = true
	}
	return result
}
// normalizeCodexModel maps an arbitrary requested model name onto one of the
// Codex upstream model IDs. The last path segment (after "/") is used, exact
// table lookups win, and otherwise a most-specific-first substring scan runs
// over the lowercased name. Unknown names fall back to "gpt-5.1".
func normalizeCodexModel(model string) string {
	const fallback = "gpt-5.1"
	if model == "" {
		return fallback
	}
	// Strip any provider prefix such as "openai/".
	modelID := model
	if idx := strings.LastIndex(modelID, "/"); idx >= 0 {
		modelID = modelID[idx+1:]
	}
	if mapped := getNormalizedCodexModel(modelID); mapped != "" {
		return mapped
	}
	normalized := strings.ToLower(modelID)
	containsAny := func(subs ...string) bool {
		for _, sub := range subs {
			if strings.Contains(normalized, sub) {
				return true
			}
		}
		return false
	}
	// Order matters: more specific families must be checked before their
	// prefixes (e.g. "gpt-5.2-codex" before "gpt-5.2").
	switch {
	case containsAny("gpt-5.2-codex", "gpt 5.2 codex"):
		return "gpt-5.2-codex"
	case containsAny("gpt-5.2", "gpt 5.2"):
		return "gpt-5.2"
	case containsAny("gpt-5.1-codex-max", "gpt 5.1 codex max"):
		return "gpt-5.1-codex-max"
	case containsAny("gpt-5.1-codex-mini", "gpt 5.1 codex mini"):
		return "gpt-5.1-codex-mini"
	case containsAny("codex-mini-latest", "gpt-5-codex-mini", "gpt 5 codex mini"):
		return "codex-mini-latest"
	case containsAny("gpt-5.1-codex", "gpt 5.1 codex"):
		return "gpt-5.1-codex"
	case containsAny("gpt-5.1", "gpt 5.1"):
		return fallback
	case containsAny("codex"):
		return "gpt-5.1-codex"
	case containsAny("gpt-5", "gpt 5"):
		return fallback
	default:
		return fallback
	}
}
// getNormalizedCodexModel resolves modelID through codexModelMap, first with
// an exact lookup and then with a case-insensitive scan. It returns "" when
// no mapping exists.
func getNormalizedCodexModel(modelID string) string {
	if modelID == "" {
		return ""
	}
	if mapped, ok := codexModelMap[modelID]; ok {
		return mapped
	}
	// Fall back to a case-insensitive scan over the table.
	lowered := strings.ToLower(modelID)
	for candidate, mapped := range codexModelMap {
		if strings.ToLower(candidate) == lowered {
			return mapped
		}
	}
	return ""
}
// getOpenCodeCachedPrompt returns the prompt text at url, backed by an
// on-disk cache with ETag revalidation.
//
// Flow: if the cached copy was last checked within codexCacheTTL, it is
// returned without any network round trip. Otherwise a conditional GET is
// issued: a 304 keeps the cached copy, a 2xx with a body rewrites the cache
// file and metadata. On any fetch failure the (possibly empty) cached content
// is returned, so callers must tolerate "".
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	// codexCachePath("") presumably resolves the cache directory; an empty
	// result means no usable cache location, so skip both cache and fetch.
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)
	var cachedContent string
	if content, ok := readFile(cacheFile); ok {
		cachedContent = content
	}
	var meta opencodeCacheMetadata
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		// Fresh enough: serve from cache without hitting the network.
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}
	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err == nil && status == http.StatusNotModified && cachedContent != "" {
		// NOTE(review): LastChecked is not refreshed on 304, so every call
		// after TTL expiry revalidates upstream — confirm this is intended.
		return cachedContent
	}
	if err == nil && status >= 200 && status < 300 && content != "" {
		// Persist the new content; cache-write failures are deliberately
		// ignored (best effort).
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   time.Now().UTC().Format(time.RFC3339),
			LastChecked: time.Now().UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}
	// Fetch failed or returned nothing usable; fall back to stale cache.
	return cachedContent
}
// getOpenCodeCodexHeader returns the Codex header instructions, preferring
// the copy cached from the opencode repository and falling back to the
// bundled Codex CLI instructions when the cache yields nothing.
func getOpenCodeCodexHeader() string {
	if header := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json"); header != "" {
		return header
	}
	return getCodexCLIInstructions()
}
// getCodexCLIInstructions returns the bundled Codex CLI instruction text
// (the package-level codexCLIInstructions value).
func getCodexCLIInstructions() string {
	return codexCLIInstructions
}
// GetOpenCodeInstructions returns the OpenCode Codex header instructions:
// the cached opencode copy when available, otherwise the bundled Codex CLI
// instructions.
func GetOpenCodeInstructions() string {
	return getOpenCodeCodexHeader()
}
// GetCodexCLIInstructions returns the bundled Codex CLI instruction content.
func GetCodexCLIInstructions() string {
	return getCodexCLIInstructions()
}
// ReplaceWithCodexInstructions overwrites reqBody["instructions"] with the
// bundled Codex instructions when they differ (compared after trimming
// surrounding whitespace). It reports whether the body was modified.
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
	want := strings.TrimSpace(getCodexCLIInstructions())
	if want == "" {
		// No bundled instructions available; leave the request untouched.
		return false
	}
	current, _ := reqBody["instructions"].(string)
	if strings.TrimSpace(current) == want {
		return false
	}
	reqBody["instructions"] = want
	return true
}
// IsInstructionError reports whether errorMessage looks like an upstream
// complaint about the request's instructions / system-prompt format.
// Matching is a case-insensitive substring search over known keywords.
func IsInstructionError(errorMessage string) bool {
	if errorMessage == "" {
		return false
	}
	lowerMsg := strings.ToLower(errorMessage)
	// NOTE: "instruction" also matches "instructions", so the plural form
	// previously listed here was redundant and has been removed.
	instructionKeywords := []string{
		"instruction",
		"system prompt",
		"system message",
		"invalid prompt",
		"prompt format",
	}
	for _, keyword := range instructionKeywords {
		if strings.Contains(lowerMsg, keyword) {
			return true
		}
	}
	return false
}
// filterCodexInput sanitizes a Codex "input" array. When preserveReferences
// is false, item_reference entries are dropped and "id" fields (plus
// "call_id" on non tool-call items) are stripped; when true, references and
// ids are kept so continuation requests retain their context. Source maps
// are never mutated: entries are shallow-copied before any edit.
func filterCodexInput(input []any, preserveReferences bool) []any {
	out := make([]any, 0, len(input))
	clone := func(src map[string]any) map[string]any {
		dst := make(map[string]any, len(src))
		for k, v := range src {
			dst[k] = v
		}
		return dst
	}
	for _, raw := range input {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			// Non-map entries pass through untouched.
			out = append(out, raw)
			continue
		}
		entryType, _ := entry["type"].(string)
		if entryType == "item_reference" {
			if preserveReferences {
				out = append(out, clone(entry))
			}
			continue
		}
		isToolCall := isCodexToolCallItemType(entryType)
		result := entry
		mutated := false
		// mutable lazily copies the entry so the original is never modified.
		mutable := func() map[string]any {
			if !mutated {
				result = clone(entry)
				mutated = true
			}
			return result
		}
		if isToolCall {
			// Backfill a missing/blank call_id from the item id.
			callID, hasCallID := entry["call_id"].(string)
			if !hasCallID || strings.TrimSpace(callID) == "" {
				if id, ok := entry["id"].(string); ok && strings.TrimSpace(id) != "" {
					mutable()["call_id"] = id
				}
			}
		}
		if !preserveReferences {
			target := mutable()
			delete(target, "id")
			if !isToolCall {
				delete(target, "call_id")
			}
		}
		out = append(out, result)
	}
	return out
}

// isCodexToolCallItemType reports whether typ names a tool-call item
// ("*_call") or a tool-call output ("*_call_output").
func isCodexToolCallItemType(typ string) bool {
	return typ != "" && (strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output"))
}
// normalizeCodexTools lifts OpenAI-style nested function-tool fields
// (tools[i].function.{name,description,parameters,strict}) up to the tool's
// top level, which the Codex upstream expects. Existing top-level fields
// are never overwritten; tool maps are modified in place. The return value
// reports whether anything changed.
func normalizeCodexTools(reqBody map[string]any) bool {
	tools, ok := reqBody["tools"].([]any)
	if !ok {
		// Missing, nil, or unexpected type: nothing to normalize.
		return false
	}
	changed := false
	for _, rawTool := range tools {
		tool, ok := rawTool.(map[string]any)
		if !ok {
			continue
		}
		if typ, _ := tool["type"].(string); strings.TrimSpace(typ) != "function" {
			continue
		}
		fn, ok := tool["function"].(map[string]any)
		if !ok {
			continue
		}
		// String fields are copied only when non-blank.
		if _, exists := tool["name"]; !exists {
			if name, ok := fn["name"].(string); ok && strings.TrimSpace(name) != "" {
				tool["name"] = name
				changed = true
			}
		}
		if _, exists := tool["description"]; !exists {
			if desc, ok := fn["description"].(string); ok && strings.TrimSpace(desc) != "" {
				tool["description"] = desc
				changed = true
			}
		}
		// Structured fields are copied verbatim when present.
		if _, exists := tool["parameters"]; !exists {
			if params, ok := fn["parameters"]; ok {
				tool["parameters"] = params
				changed = true
			}
		}
		if _, exists := tool["strict"]; !exists {
			if strict, ok := fn["strict"]; ok {
				tool["strict"] = strict
				changed = true
			}
		}
	}
	if changed {
		reqBody["tools"] = tools
	}
	return changed
}
// codexCachePath returns the opencode cache directory (when filename is
// empty) or the path of filename inside that directory. It returns "" when
// the user home directory cannot be determined.
func codexCachePath(filename string) string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}
	dir := filepath.Join(home, ".opencode", "cache")
	if filename != "" {
		return filepath.Join(dir, filename)
	}
	return dir
}
// readFile returns the contents of path and true on success, or "" and
// false when the path is empty or the file cannot be read.
func readFile(path string) (string, bool) {
	if path == "" {
		return "", false
	}
	if data, err := os.ReadFile(path); err == nil {
		return string(data), true
	}
	return "", false
}
// writeFile persists content at path with mode 0644, creating any missing
// parent directories. An empty path is rejected.
func writeFile(path, content string) error {
	if path == "" {
		return fmt.Errorf("empty cache path")
	}
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}
	return os.WriteFile(path, []byte(content), 0o644)
}
// loadJSON reads the file at path and unmarshals it into target, reporting
// whether both the read and the decode succeeded.
func loadJSON(path string, target any) bool {
	data, err := os.ReadFile(path)
	if err != nil {
		return false
	}
	return json.Unmarshal(data, target) == nil
}
// writeJSON marshals value and writes it at path with mode 0644, creating
// any missing parent directories first. An empty path is rejected.
func writeJSON(path string, value any) error {
	if path == "" {
		return fmt.Errorf("empty json path")
	}
	// Directory creation happens before marshaling, mirroring the original
	// side-effect order.
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	payload, err := json.Marshal(value)
	if err != nil {
		return err
	}
	return os.WriteFile(path, payload, 0o644)
}
// fetchWithETag performs a conditional GET against url. When etag is
// non-empty it is sent as If-None-Match, allowing the server to answer
// 304 Not Modified. It returns the body, the response ETag header, the
// HTTP status code, and any transport or read error.
func fetchWithETag(url, etag string) (string, string, int, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", "", 0, err
	}
	req.Header.Set("User-Agent", "sub2api-codex")
	if etag != "" {
		req.Header.Set("If-None-Match", etag)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", 0, err
	}
	defer func() { _ = resp.Body.Close() }()
	status := resp.StatusCode
	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return "", "", status, readErr
	}
	// Header.Get is case-insensitive, so "etag" matches "ETag".
	return string(payload), resp.Header.Get("etag"), status, nil
}

View File

@@ -0,0 +1,167 @@
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestApplyCodexOAuthTransform_ToolContinuationPreservesInput covers the
// continuation scenario: item_reference entries and ids are preserved, and
// store=true is no longer forced.
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.2",
		"input": []any{
			map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
			map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	// store was not set explicitly, so it must default to false.
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 2)
	// Assert input[0] is a map so the field assertions below cannot panic.
	first, ok := input[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "item_reference", first["type"])
	require.Equal(t, "ref1", first["id"])
	// Assert input[1] is a map so the subsequent field assertion is safe.
	second, ok := input[1].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "o1", second["id"])
}
// TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved covers the
// continuation scenario where an explicit store=false is kept as false
// rather than being forced to true.
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": false,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse verifies that
// even an explicit store=true is forced to false by the transform.
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"store": true,
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
		"tool_choice": "auto",
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
}
// TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs
// covers the non-continuation scenario: when store is unset it defaults to
// false, and ids are removed from input items.
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{
			map[string]any{"type": "text", "id": "t1", "text": "hi"},
		},
	}
	applyCodexOAuthTransform(reqBody)
	store, ok := reqBody["store"].(bool)
	require.True(t, ok)
	require.False(t, store)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 1)
	// Assert input[0] is a map so the id check below is type-safe.
	item, ok := input[0].(map[string]any)
	require.True(t, ok)
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved verifies that
// item_reference entries are dropped and ids stripped when references are
// not preserved.
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
	input := []any{
		map[string]any{"type": "item_reference", "id": "ref1"},
		map[string]any{"type": "text", "id": "t1", "text": "hi"},
	}
	filtered := filterCodexInput(input, false)
	require.Len(t, filtered, 1)
	// Assert filtered[0] is a map so the field checks are reliable.
	item, ok := filtered[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "text", item["type"])
	_, hasID := item["id"]
	require.False(t, hasID)
}
// TestApplyCodexOAuthTransform_EmptyInput verifies that an empty input
// slice stays empty and the transform does not panic on it.
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
	setupCodexCache(t)
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"input": []any{},
	}
	applyCodexOAuthTransform(reqBody)
	input, ok := reqBody["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 0)
}
// setupCodexCache points HOME at a temp dir pre-populated with a fresh
// opencode header cache so tests never fetch the header over the network.
// NOTE(review): overriding HOME may not affect os.UserHomeDir on Windows
// (which reads USERPROFILE) — confirm if these tests must run there.
func setupCodexCache(t *testing.T) {
	t.Helper()
	tempDir := t.TempDir()
	t.Setenv("HOME", tempDir)
	cacheDir := filepath.Join(tempDir, ".opencode", "cache")
	require.NoError(t, os.MkdirAll(cacheDir, 0o755))
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
	// A current lastChecked keeps the cached header within its TTL.
	meta := map[string]any{
		"etag":        "",
		"lastFetch":   time.Now().UTC().Format(time.RFC3339),
		"lastChecked": time.Now().UnixMilli(),
	}
	data, err := json.Marshal(meta)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
}

View File

@@ -20,6 +20,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
@@ -41,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
@@ -84,12 +86,15 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -100,12 +105,14 @@ func NewOpenAIGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -114,12 +121,15 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
}
}
@@ -158,7 +168,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
@@ -169,16 +179,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
}
// 2. Get schedulable OpenAI accounts
var accounts []Account
var err error
// 简易模式:忽略分组限制,查询所有可用账号
if s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
accounts, err := s.listSchedulableAccounts(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
@@ -190,6 +191,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// Check model support
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
@@ -300,7 +306,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
@@ -336,6 +342,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
@@ -445,6 +457,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
return accounts, err
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -467,6 +483,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
}
return s.accountRepo.GetByID(ctx, accountID)
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
@@ -485,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
if s.openAITokenProvider != nil {
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 降级TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
@@ -511,7 +543,7 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
}
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
@@ -528,33 +560,97 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
// Apply model mapping
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
bodyModified = true
}
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
if model, ok := reqBody["model"].(string); ok {
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" && normalizedModel != model {
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
model, normalizedModel, account.Name, account.Type, isCodexCLI)
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
bodyModified = true
}
}
// 规范化 reasoning.effort 参数minimal -> none与上游允许值对齐。
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"
bodyModified = true
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
}
}
if account.Type == AccountTypeOAuth && !isCodexCLI {
codexResult := applyCodexOAuthTransform(reqBody)
if codexResult.Modified {
bodyModified = true
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
}
}
// Handle max_output_tokens based on platform and account type
if !isCodexCLI {
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
switch account.Platform {
case PlatformOpenAI:
// For OpenAI API Key, remove max_output_tokens (not supported)
// For OpenAI OAuth (Responses API), keep it (supported)
if account.Type == AccountTypeAPIKey {
delete(reqBody, "max_output_tokens")
bodyModified = true
}
case PlatformAnthropic:
// For Anthropic (Claude), convert to max_tokens
delete(reqBody, "max_output_tokens")
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
reqBody["max_tokens"] = maxOutputTokens
}
bodyModified = true
case PlatformGemini:
// For Gemini, remove (will be handled by Gemini-specific transform)
delete(reqBody, "max_output_tokens")
bodyModified = true
default:
// For unknown platforms, remove to be safe
delete(reqBody, "max_output_tokens")
bodyModified = true
}
}
// Also handle max_completion_tokens (similar logic)
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
delete(reqBody, "max_completion_tokens")
bodyModified = true
}
}
}
// Re-serialize body only if modified
if bodyModified {
var err error
@@ -571,7 +667,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
return nil, err
}
@@ -582,16 +678,63 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// Handle error response
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -632,7 +775,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
@@ -672,12 +815,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
// Set accept header based on stream mode
if isStream {
req.Header.Set("accept", "text/event-stream")
} else {
req.Header.Set("accept", "application/json")
}
}
// Whitelist passthrough headers
@@ -689,6 +826,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
if account.Type == AccountTypeOAuth {
req.Header.Set("OpenAI-Beta", "responses=experimental")
if isCodexCLI {
req.Header.Set("originator", "codex_cli_rs")
} else {
req.Header.Set("originator", "opencode")
}
req.Header.Set("accept", "text/event-stream")
if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey)
req.Header.Set("session_id", promptCacheKey)
}
}
// Apply custom User-Agent if configured
customUA := account.GetOpenAIUserAgent()
@@ -705,17 +855,53 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream gateway error",
},
})
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Handle upstream error (mark account status)
@@ -723,6 +909,20 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -761,7 +961,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
},
})
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
@@ -905,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
@@ -933,6 +1141,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
@@ -988,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
// correctToolCallsInResponseBody rewrites tool calls in a response body via
// the service's tool corrector. The input is returned unchanged when it is
// empty or when no correction applies.
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
	if len(body) == 0 {
		return body
	}
	bodyStr := string(body)
	corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
	if changed {
		return []byte(corrected)
	}
	return body
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
@@ -1016,6 +1242,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return nil, err
}
if account.Type == AccountTypeOAuth {
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
}
}
// Parse usage
var response struct {
Usage struct {
@@ -1055,6 +1288,112 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return usage, nil
}
// isEventStreamResponse reports whether the response Content-Type header
// declares an SSE payload (contains "text/event-stream").
func isEventStreamResponse(header http.Header) bool {
	contentType := strings.ToLower(header.Get("Content-Type"))
	return strings.Contains(contentType, "text/event-stream")
}
// handleOAuthSSEToJSON converts an OAuth-account SSE reply for a
// non-streaming request into a single client response. When the stream
// contains a response.done/response.completed event, its embedded response
// object is returned as JSON with usage read from it; otherwise the raw
// SSE body is forwarded and usage is summed from the individual events.
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
	bodyText := string(body)
	finalResponse, ok := extractCodexFinalResponse(bodyText)
	usage := &OpenAIUsage{}
	if ok {
		// Extract token usage from the final response object.
		var response struct {
			Usage struct {
				InputTokens       int `json:"input_tokens"`
				OutputTokens      int `json:"output_tokens"`
				InputTokenDetails struct {
					CachedTokens int `json:"cached_tokens"`
				} `json:"input_tokens_details"`
			} `json:"usage"`
		}
		if err := json.Unmarshal(finalResponse, &response); err == nil {
			usage.InputTokens = response.Usage.InputTokens
			usage.OutputTokens = response.Usage.OutputTokens
			usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
		}
		body = finalResponse
		// Restore the client's original model name if mapping was applied.
		if originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
		// Correct tool calls in final response
		body = s.correctToolCallsInResponseBody(body)
	} else {
		// No final event found: forward the SSE body and sum per-event usage.
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}
	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	// JSON when a final response was extracted; otherwise mirror the
	// upstream content type, defaulting to SSE.
	contentType := "application/json; charset=utf-8"
	if !ok {
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)
	return usage, nil
}
// extractCodexFinalResponse scans an SSE body for a response.done or
// response.completed event and returns that event's "response" payload.
// The boolean reports whether a non-empty payload was found.
func extractCodexFinalResponse(body string) ([]byte, bool) {
	lines := strings.Split(body, "\n")
	for _, line := range lines {
		if !openaiSSEDataRe.MatchString(line) {
			continue
		}
		data := openaiSSEDataRe.ReplaceAllString(line, "")
		if data == "" || data == "[DONE]" {
			continue
		}
		var event struct {
			Type     string          `json:"type"`
			Response json.RawMessage `json:"response"`
		}
		if json.Unmarshal([]byte(data), &event) != nil {
			// Not valid JSON for this shape: skip the event.
			continue
		}
		if event.Type == "response.done" || event.Type == "response.completed" {
			if len(event.Response) > 0 {
				return event.Response, true
			}
		}
	}
	return nil, false
}
// parseSSEUsageFromBody walks a complete SSE body line by line and folds the
// usage information of every data payload into a single OpenAIUsage.
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
	total := &OpenAIUsage{}
	for _, line := range strings.Split(body, "\n") {
		if !openaiSSEDataRe.MatchString(line) {
			continue
		}
		payload := openaiSSEDataRe.ReplaceAllString(line, "")
		if payload == "" || payload == "[DONE]" {
			continue
		}
		s.parseSSEUsage(payload, total)
	}
	return total
}
// replaceModelInSSEBody rewrites the model name inside every SSE data line of
// body, leaving non-data lines (events, comments, blanks) untouched.
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
	lines := strings.Split(body, "\n")
	for i := range lines {
		if openaiSSEDataRe.MatchString(lines[i]) {
			lines[i] = s.replaceModelInSSELine(lines[i], fromModel, toModel)
		}
	}
	return strings.Join(lines, "\n")
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -1094,101 +1433,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody
}
// normalizeInputForCodexAPI converts AI SDK multi-part content format to the
// simplified format expected by the ChatGPT internal Codex API.
//
// AI SDK sends content as an array of typed objects:
//
//	{"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
//	{"content": "hello"}
//
// Text parts are joined with newlines; image/file parts have no place in the
// string form and are dropped. reqBody is modified in place; the return value
// reports whether any message was rewritten.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
	rawInput, present := reqBody["input"]
	if !present {
		return false
	}
	// A plain-string input is already in the target format.
	if _, alreadyString := rawInput.(string); alreadyString {
		return false
	}
	messages, isArray := rawInput.([]any)
	if !isArray {
		return false
	}
	changed := false
	for _, rawMsg := range messages {
		msg, isMap := rawMsg.(map[string]any)
		if !isMap {
			continue
		}
		content, hasContent := msg["content"]
		if !hasContent {
			continue
		}
		// String content needs no conversion.
		if _, isStr := content.(string); isStr {
			continue
		}
		parts, isParts := content.([]any)
		if !isParts {
			continue
		}
		// Collect the text carried by the typed parts.
		var texts []string
		for _, rawPart := range parts {
			part, okPart := rawPart.(map[string]any)
			if !okPart {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "input_image", "image", "input_file", "file":
				// Non-text parts are skipped: the string content form cannot
				// represent them. TODO: consider preserving them separately.
				continue
			default:
				// "input_text", "text", and unknown part types all contribute
				// their "text" field when one is present.
				if txt, okTxt := part["text"].(string); okTxt {
					texts = append(texts, txt)
				}
			}
		}
		// Only rewrite when at least one text part was found; otherwise the
		// original array is left in place unchanged.
		if len(texts) > 0 {
			msg["content"] = strings.Join(texts, "\n")
			changed = true
		}
	}
	return changed
}
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
@@ -1197,6 +1441,7 @@ type OpenAIRecordUsageInput struct {
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage records usage and deducts balance
@@ -1242,28 +1487,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: actualInputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加 UserAgent
@@ -1271,6 +1518,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}

View File

@@ -3,6 +3,7 @@ package service
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net/http"
@@ -15,6 +16,129 @@ import (
"github.com/gin-gonic/gin"
)
// stubOpenAIAccountRepo is a test double that returns the same fixed account
// list for every query, ignoring group and platform filters.
type stubOpenAIAccountRepo struct {
	AccountRepository
	accounts []Account
}

func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
	out := make([]Account, len(r.accounts))
	copy(out, r.accounts)
	return out, nil
}

func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
	out := make([]Account, len(r.accounts))
	copy(out, r.accounts)
	return out, nil
}
// stubConcurrencyCache is a permissive concurrency-cache double: every slot
// acquisition succeeds and every account reports zero load.
type stubConcurrencyCache struct {
	ConcurrencyCache
}

func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	return true, nil
}

func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	return nil
}

func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	loads := make(map[int64]*AccountLoadInfo, len(accounts))
	for i := range accounts {
		id := accounts[i].ID
		loads[id] = &AccountLoadInfo{AccountID: id, LoadRate: 0}
	}
	return loads, nil
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable verifies that
// on the load-aware path (concurrency service present) a rate-limited account
// is skipped and the remaining schedulable account is picked.
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
	reset := time.Now().Add(10 * time.Minute)
	group := int64(1)
	limited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0,
		RateLimitResetAt: &reset,
	}
	healthy := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}
	svc := &OpenAIGatewayService{
		accountRepo:        stubOpenAIAccountRepo{accounts: []Account{limited, healthy}},
		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
	}
	sel, err := svc.SelectAccountWithLoadAwareness(context.Background(), &group, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if sel == nil || sel.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if sel.Account.ID != healthy.ID {
		t.Fatalf("expected account %d, got %d", healthy.ID, sel.Account.ID)
	}
	if sel.ReleaseFunc != nil {
		sel.ReleaseFunc()
	}
}
// TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService
// verifies that rate-limited accounts are also filtered on the fallback
// selection path taken when no concurrency service is configured.
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
	reset := time.Now().Add(10 * time.Minute)
	group := int64(1)
	limited := Account{
		ID:               1,
		Platform:         PlatformOpenAI,
		Type:             AccountTypeAPIKey,
		Status:           StatusActive,
		Schedulable:      true,
		Concurrency:      1,
		Priority:         0,
		RateLimitResetAt: &reset,
	}
	healthy := Account{
		ID:          2,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Priority:    1,
	}
	svc := &OpenAIGatewayService{
		accountRepo: stubOpenAIAccountRepo{accounts: []Account{limited, healthy}},
		// concurrencyService left nil to force the non-load-batch path.
	}
	sel, err := svc.SelectAccountWithLoadAwareness(context.Background(), &group, "", "gpt-5.2", nil)
	if err != nil {
		t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
	}
	if sel == nil || sel.Account == nil {
		t.Fatalf("expected selection with account")
	}
	if sel.Account.ID != healthy.ID {
		t.Fatalf("expected account %d, got %d", healthy.ID, sel.Account.ID)
	}
	if sel.ReleaseFunc != nil {
		sel.ReleaseFunc()
	}
}
func TestOpenAIStreamingTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -220,7 +344,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}

View File

@@ -0,0 +1,133 @@
package service
import (
"strings"
"testing"
)
// TestOpenAIGatewayService_ToolCorrection exercises the tool-call correction
// integration inside OpenAIGatewayService via correctToolCallsInResponseBody.
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
	// Build a minimal service instance carrying only the tool corrector.
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	tests := []struct {
		name     string
		input    []byte
		expected string
		changed  bool
	}{
		{
			name: "correct apply_patch in response body",
			input: []byte(`{
				"choices": [{
					"message": {
						"tool_calls": [{
							"function": {"name": "apply_patch"}
						}]
					}
				}]
			}`),
			expected: "edit",
			changed:  true,
		},
		{
			name: "correct update_plan in response body",
			input: []byte(`{
				"tool_calls": [{
					"function": {"name": "update_plan"}
				}]
			}`),
			expected: "todowrite",
			changed:  true,
		},
		{
			name: "no change for correct tool name",
			input: []byte(`{
				"tool_calls": [{
					"function": {"name": "edit"}
				}]
			}`),
			expected: "edit",
			changed:  false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := service.correctToolCallsInResponseBody(tt.input)
			resultStr := string(result)
			// The corrected body must contain the expected tool name.
			if !strings.Contains(resultStr, tt.expected) {
				t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
			}
			// When a correction is expected, the output must differ from the input.
			if tt.changed && string(result) == string(tt.input) {
				t.Error("expected result to be different from input, but they are the same")
			}
			// When no correction is expected, the output must equal the input.
			if !tt.changed && string(result) != string(tt.input) {
				t.Error("expected result to be same as input, but they are different")
			}
		})
	}
}
// TestOpenAIGatewayService_ToolCorrectorInitialization checks that the tool
// corrector is wired into the service and performs a basic SSE correction.
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
	svc := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	if svc.toolCorrector == nil {
		t.Fatal("toolCorrector should not be nil")
	}
	// A simple correction should succeed end to end.
	payload := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
	fixed, changed := svc.toolCorrector.CorrectToolCallsInSSEData(payload)
	if !changed {
		t.Error("expected tool call to be corrected")
	}
	if !strings.Contains(fixed, "edit") {
		t.Errorf("expected corrected data to contain 'edit', got %q", fixed)
	}
}
// TestToolCorrectionStats verifies that correction statistics count both the
// total number of corrections and the per-mapping breakdown.
func TestToolCorrectionStats(t *testing.T) {
	service := &OpenAIGatewayService{
		toolCorrector: NewCodexToolCorrector(),
	}
	// Run several corrections to populate the counters.
	testData := []string{
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
		`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
	}
	for _, data := range testData {
		service.toolCorrector.CorrectToolCallsInSSEData(data)
	}
	stats := service.toolCorrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
	}
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
}

View File

@@ -0,0 +1,189 @@
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
	// openAITokenRefreshSkew: refresh when the token expires within this window.
	openAITokenRefreshSkew = 3 * time.Minute
	// openAITokenCacheSkew: safety margin subtracted from the expiry when
	// computing the cache TTL.
	openAITokenCacheSkew = 5 * time.Minute
	// openAILockWaitTime: pause before re-reading the cache when another
	// worker holds the refresh lock.
	openAILockWaitTime = 200 * time.Millisecond
)

// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache
// interface definition).
type OpenAITokenCache = GeminiTokenCache

// OpenAITokenProvider manages access_token lifecycles for OpenAI OAuth accounts.
type OpenAITokenProvider struct {
	accountRepo        AccountRepository
	tokenCache         OpenAITokenCache
	openAIOAuthService *OpenAIOAuthService
}

// NewOpenAITokenProvider wires a provider with its repository, cache, and
// OAuth service dependencies; any of them may be nil (degraded behavior).
func NewOpenAITokenProvider(
	accountRepo AccountRepository,
	tokenCache OpenAITokenCache,
	openAIOAuthService *OpenAIOAuthService,
) *OpenAITokenProvider {
	return &OpenAITokenProvider{
		accountRepo:        accountRepo,
		tokenCache:         tokenCache,
		openAIOAuthService: openAIOAuthService,
	}
}
// GetAccessToken returns a valid access_token for an OpenAI OAuth account.
//
// Resolution order:
//  1. the distributed token cache;
//  2. a refresh, guarded by a distributed lock, when the stored token has no
//     expiry or expires within openAITokenRefreshSkew;
//  3. the access_token currently stored in the account credentials.
//
// Refresh failures are logged but not fatal: the existing token is returned
// and cached with a short TTL so a broken refresh cannot pin a stale token.
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Try the cache first.
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			slog.Debug("openai_token_cache_hit", "account_id", account.ID)
			return token, nil
		} else if err != nil {
			slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
		}
	}
	slog.Debug("openai_token_cache_miss", "account_id", account.ID)
	// 2. Refresh when the token is missing an expiry or about to expire.
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if lockErr == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Re-check the cache after taking the lock (another worker may
			// have refreshed already).
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				return token, nil
			}
			// Load the latest account state from the database.
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true // cannot refresh: mark as failed
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						// Log and fall through to the existing token rather
						// than failing the request outright.
						slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
						refreshFailed = true // refresh failed: use a short cache TTL
					} else {
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						// Preserve credential keys the refresh did not supply.
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if lockErr != nil {
			// A Redis error prevented acquiring the lock: degrade to a
			// lockless refresh (only when the token is close to expiry).
			slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
			// Bail out if the context has been cancelled.
			if ctx.Err() != nil {
				return "", ctx.Err()
			}
			// Load the latest account state from the database.
			if p.accountRepo != nil {
				fresh, err := p.accountRepo.GetByID(ctx, account.ID)
				if err == nil && fresh != nil {
					account = fresh
				}
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			// Only refresh locklessly when expires_at is missing or imminent.
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.openAIOAuthService == nil {
					slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
					refreshFailed = true
				} else {
					tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
					if err != nil {
						slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
						refreshFailed = true
					} else {
						newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
							slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
						}
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else {
			// The lock is held by another worker: wait briefly, then re-read
			// the cache hoping the holder has published a fresh token.
			time.Sleep(openAILockWaitTime)
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
				slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if strings.TrimSpace(accessToken) == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in the cache.
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			// Short TTL after a failed refresh so a dead token cannot stay
			// cached long enough to cause sustained 401 churn.
			ttl = time.Minute
			slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			switch {
			case until > openAITokenCacheSkew:
				ttl = until - openAITokenCacheSkew
			case until > 0:
				ttl = until
			default:
				ttl = time.Minute
			}
		}
		if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
			slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
		}
	}
	return accessToken, nil
}

View File

@@ -0,0 +1,810 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// openAITokenCacheStub implements OpenAITokenCache for testing
type openAITokenCacheStub struct {
	mu     sync.Mutex
	tokens map[string]string
	// Injectable errors for individual cache operations.
	getErr         error
	setErr         error
	deleteErr      error
	lockAcquired   bool
	lockErr        error
	releaseLockErr error
	// Call counters, accessed atomically so tests can assert on them.
	getCalled        int32
	setCalled        int32
	lockCalled       int32
	unlockCalled     int32
	simulateLockRace bool
}

// newOpenAITokenCacheStub returns a stub whose lock is acquirable by default.
func newOpenAITokenCacheStub() *openAITokenCacheStub {
	return &openAITokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

// GetAccessToken returns the cached token (or the injected error).
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

// SetAccessToken stores the token (or returns the injected error).
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

// DeleteAccessToken removes the token (or returns the injected error).
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

// AcquireRefreshLock honors lockErr, then simulateLockRace (lock held by a
// competing worker), then the configured lockAcquired result.
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

// ReleaseRefreshLock counts the call and returns the injected error, if any.
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
type openAIAccountRepoStub struct {
	account *Account // value returned by GetByID and replaced by Update
	getErr  error    // injected error for GetByID
	updateErr error  // injected error for Update
	// Atomic call counters for assertions.
	getCalled    int32
	updateCalled int32
}

// GetByID returns the stored account or the injected error; id is ignored.
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

// Update replaces the stored account or returns the injected error.
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
type openAIOAuthServiceStub struct {
	tokenInfo     *OpenAITokenInfo // returned on a successful refresh
	refreshErr    error            // injected refresh failure
	refreshCalled int32            // atomic call counter
}

// RefreshAccountToken returns the stubbed token info or the injected error.
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}

// BuildAccountCredentials packs the token info into a credentials map with an
// RFC3339 expires_at computed from ExpiresIn (seconds from now).
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
	now := time.Now()
	return map[string]any{
		"access_token":  info.AccessToken,
		"refresh_token": info.RefreshToken,
		"expires_at":    now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
}
// TestOpenAITokenProvider_CacheHit: a cached token is returned directly and
// nothing is written back to the cache.
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

// TestOpenAITokenProvider_CacheMiss_FromCredentials: on a cache miss with a
// far-future expiry, the credential token is returned and written to the cache.
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)
	// Should have stored in cache
	cacheKey := OpenAITokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
// TestOpenAITokenProvider_TokenRefresh: a token within the refresh skew is
// refreshed via the OAuth service (exercised through testOpenAITokenProvider,
// which accepts the stubbed OAuth service).
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	// We need to directly test with the stub - create a custom provider
	customProvider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	token, err := customProvider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
// testOpenAITokenProvider is a test version that uses the stub OAuth service
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub
	tokenCache   *openAITokenCacheStub
	oauthService *openAIOAuthServiceStub
}

// GetAccessToken mirrors OpenAITokenProvider.GetAccessToken closely enough to
// exercise the cache / refresh-lock / fallback logic with the stubbed
// dependencies: cache first, locked refresh when near expiry, then the
// credential token (cached with a short TTL after a failed refresh).
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}
	cacheKey := OpenAITokenCacheKey(account)
	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}
	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh: mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed: use a short TTL below
					} else {
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						// Preserve credential keys the refresh did not supply.
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond) // Short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}
	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}
	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // short TTL after a failed refresh
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}
	return accessToken, nil
}
// TestOpenAITokenProvider_LockRaceCondition: when the refresh lock is held by
// a competing worker, the provider waits and then picks up either the token
// published by the "winner" or falls back to its own credential token.
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_NilAccount: a nil account is rejected outright.
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongPlatform: non-OpenAI accounts are rejected.
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongAccountType: non-OAuth OpenAI accounts are rejected.
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_NilCache: with no cache configured the credential
// token is still returned as long as it is not near expiry.
func TestOpenAITokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, nil, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

// TestOpenAITokenProvider_CacheGetError: cache read failures degrade to the
// credential token instead of failing the call.
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.getErr = errors.New("redis connection failed")
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}

// TestOpenAITokenProvider_CacheSetError: cache write failures are non-fatal.
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.setErr = errors.New("redis write failed")
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
cache := newOpenAITokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
// TestOpenAITokenProvider_RefreshError verifies the graceful-degradation path:
// when the OAuth refresh fails for a token that is about to expire, the
// provider falls back to the still-present (soon-to-expire) credential token
// instead of surfacing the refresh error to the caller.
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestOpenAITokenProvider_OAuthServiceNotConfigured verifies that a nil OAuth
// service is tolerated: a soon-to-expire token cannot be refreshed, so the
// existing credential token is returned unchanged.
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}
	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
// TestOpenAITokenProvider_TTLCalculation runs GetAccessToken with tokens of
// varying remaining lifetime and checks the token ends up in the cache.
// NOTE(review): the cache stub does not expose the TTL actually passed to Set,
// so only the cached value is asserted here — the TTL computation itself is
// not verified. Consider extending the stub to record the TTL.
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newOpenAITokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}
			provider := NewOpenAITokenProvider(nil, cache, nil)
			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)
			// Verify token was cached
			cacheKey := OpenAITokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}
// TestOpenAITokenProvider_DoubleCheckAfterLock simulates another worker
// populating the cache while a refresh is in flight: the cache entry starts
// empty, a goroutine fills it shortly afterwards, and the provider must return
// a non-empty token — either the freshly refreshed one or the one cached
// concurrently by the "other worker".
//
// Fix: removed the dead `originalGet` counter and its `_ = originalGet`
// unused-variable suppression; it was never read.
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh",
			ExpiresIn:    3600,
		},
	}
	// Token expires soon, forcing the refresh path.
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       112,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "" // Empty initially: first cache lookup misses
	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}
	// In a goroutine, set the cached token after a small delay (simulating race)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "cached-by-other"
		cache.mu.Unlock()
	}()
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Either outcome of the race is acceptable; the token only must be non-empty.
	require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
// TestOpenAITokenProvider_Real_LockFailedWait covers the path where the
// refresh lock cannot be acquired: the provider waits, and either picks up the
// token another worker wrote to the cache meanwhile or falls back to the
// credential token. Only non-emptiness is asserted because either outcome of
// the race is valid.
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       200,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the fallback token or the refreshed one
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_CacheHitAfterWait: with the refresh lock
// unavailable, another worker writes a token into the cache shortly after the
// wait starts; the provider should come back with a non-empty token (the
// cached winner or the original credential token, depending on timing).
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails
	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
// TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken: credentials carry
// neither expires_at nor refresh_token, so no refresh can be attempted and the
// credential token is returned as-is.
// NOTE(review): despite the name, the token here has no expiry at all rather
// than being expired — consider renaming or adding an expired-token variant.
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic
	// Token with nil expires_at (no expiry set) - should use credentials
	account := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	// Without OAuth service, refresh will fail but token should be returned from credentials
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}
// TestOpenAITokenProvider_Real_WhitespaceToken verifies that a cached token
// consisting only of whitespace is treated as a cache miss and the provider
// falls back to the token stored in the account credentials.
//
// Fix: the cache key was hard-coded as "openai:account:203", silently coupling
// the test to the key format; it is now derived via OpenAITokenCacheKey, the
// same way every sibling test does it.
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       203,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}
	// Whitespace only - should be treated as empty
	cache.tokens[OpenAITokenCacheKey(account)] = "   "
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token) // Should fall back to credentials
}
// TestOpenAITokenProvider_Real_LockError verifies that an error while taking
// the refresh lock (e.g. Redis unavailable) degrades gracefully to the
// credential token instead of failing the request.
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")
	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}
// TestOpenAITokenProvider_Real_WhitespaceCredentialToken verifies that a
// whitespace-only access_token in the credentials is rejected the same way as
// a missing one.
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "   ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
// TestOpenAITokenProvider_Real_NilCredentials verifies the error path when the
// account has no credentials map at all.
//
// Fix: the original test built a non-nil map that merely lacked access_token,
// which duplicated TestOpenAITokenProvider_MissingAccessToken and never
// exercised the nil-map case its name promises. Credentials is now genuinely
// nil; reading a key from a nil map yields the zero value, so the same
// "access_token not found" error is expected.
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:          206,
		Platform:    PlatformOpenAI,
		Type:        AccountTypeOAuth,
		Credentials: nil, // no credentials at all
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

View File

@@ -0,0 +1,213 @@
package service
import "strings"
// NeedsToolContinuation reports whether the request requires tool-call
// continuation handling. Any one of these signals is sufficient:
// previous_response_id, a function_call_output or item_reference item in the
// input array, or an explicit tools / tool_choice declaration.
func NeedsToolContinuation(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	// Signals are checked in the same order as before; || short-circuits.
	return hasNonEmptyString(reqBody["previous_response_id"]) ||
		hasToolsSignal(reqBody) ||
		hasToolChoiceSignal(reqBody) ||
		inputHasType(reqBody, "function_call_output") ||
		inputHasType(reqBody, "item_reference")
}
// HasFunctionCallOutput reports whether the request's input array contains a
// function_call_output item, which triggers continuation validation.
func HasFunctionCallOutput(reqBody map[string]any) bool {
	return reqBody != nil && inputHasType(reqBody, "function_call_output")
}
// HasToolCallContext reports whether the input contains a tool_call or
// function_call item carrying a non-blank call_id, i.e. whether a
// function_call_output has a context it can be associated with.
func HasToolCallContext(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		switch entry["type"] {
		case "tool_call", "function_call":
			callID, _ := entry["call_id"].(string)
			if strings.TrimSpace(callID) != "" {
				return true
			}
		}
	}
	return false
}
// FunctionCallOutputCallIDs extracts the deduplicated set of call_id values
// carried by function_call_output items in the request input.
//
// Only non-blank call_ids are returned. Each ID is returned with surrounding
// whitespace trimmed: HasItemReferenceForCallIDs trims item_reference IDs
// before matching, so an untrimmed call_id could never match its reference —
// the previous code validated with TrimSpace but stored the raw value.
// Returns nil when the request is nil, input is not an array, or no usable
// call_id is present.
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
	if reqBody == nil {
		return nil
	}
	input, ok := reqBody["input"].([]any)
	if !ok {
		return nil
	}
	seen := make(map[string]struct{})
	for _, item := range input {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if itemType, _ := itemMap["type"].(string); itemType != "function_call_output" {
			continue
		}
		raw, _ := itemMap["call_id"].(string)
		// Store the trimmed form so it can match trimmed reference IDs.
		callID := strings.TrimSpace(raw)
		if callID == "" {
			continue
		}
		seen[callID] = struct{}{}
	}
	if len(seen) == 0 {
		return nil
	}
	result := make([]string, 0, len(seen))
	for id := range seen {
		result = append(result, id)
	}
	return result
}
// HasFunctionCallOutputMissingCallID reports whether any function_call_output
// item in the input lacks a usable (non-blank) call_id.
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
	if reqBody == nil {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType != "function_call_output" {
			continue
		}
		// A missing key, a non-string value, or a blank string all count as
		// "missing call_id".
		callID, _ := entry["call_id"].(string)
		if strings.TrimSpace(callID) == "" {
			return true
		}
	}
	return false
}
// HasItemReferenceForCallIDs reports whether the item_reference IDs in the
// input cover every ID in callIDs. Used to validate continuation scenarios
// that rely solely on reference items.
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
	if reqBody == nil || len(callIDs) == 0 {
		return false
	}
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	// Collect the trimmed, non-blank reference IDs.
	refs := map[string]struct{}{}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType != "item_reference" {
			continue
		}
		id, _ := entry["id"].(string)
		if id = strings.TrimSpace(id); id != "" {
			refs[id] = struct{}{}
		}
	}
	if len(refs) == 0 {
		return false
	}
	// Every call_id must be covered by some reference.
	for _, want := range callIDs {
		if _, covered := refs[want]; !covered {
			return false
		}
	}
	return true
}
// inputHasType reports whether the input array contains an item whose "type"
// field equals want.
func inputHasType(reqBody map[string]any, want string) bool {
	items, ok := reqBody["input"].([]any)
	if !ok {
		return false
	}
	for _, raw := range items {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if entryType, _ := entry["type"].(string); entryType == want {
			return true
		}
	}
	return false
}
// hasNonEmptyString reports whether value is a string with at least one
// non-whitespace character.
func hasNonEmptyString(value any) bool {
	s, isString := value.(string)
	if !isString {
		return false
	}
	return strings.TrimSpace(s) != ""
}
// hasToolsSignal reports whether the request explicitly declares a non-empty
// tools array. Absent, nil, or non-array values carry no signal.
func hasToolsSignal(reqBody map[string]any) bool {
	switch tools := reqBody["tools"].(type) {
	case []any:
		return len(tools) > 0
	default:
		return false
	}
}
// hasToolChoiceSignal reports whether tool_choice is explicitly declared: a
// non-blank string or a non-empty object. Anything else (absent, nil, other
// types) carries no signal.
func hasToolChoiceSignal(reqBody map[string]any) bool {
	switch value := reqBody["tool_choice"].(type) {
	case string:
		return strings.TrimSpace(value) != ""
	case map[string]any:
		return len(value) > 0
	default:
		return false
	}
}

View File

@@ -0,0 +1,98 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNeedsToolContinuationSignals covers every signal source that triggers
// tool-call continuation, ensuring the decision logic is complete.
func TestNeedsToolContinuationSignals(t *testing.T) {
	cases := []struct {
		name string
		body map[string]any
		want bool
	}{
		{name: "nil", body: nil, want: false},
		{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
		{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
		{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
		{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
		{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
		{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
		{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
		{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
		{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
		{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
		{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
		})
	}
}
// TestHasFunctionCallOutput: only a function_call_output item inside the input
// array counts as continuation output; nil bodies and non-array inputs do not.
func TestHasFunctionCallOutput(t *testing.T) {
	require.False(t, HasFunctionCallOutput(nil))
	require.True(t, HasFunctionCallOutput(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}))
	require.False(t, HasFunctionCallOutput(map[string]any{
		"input": "text",
	}))
}
// TestHasToolCallContext: a tool_call/function_call item must carry a call_id
// to qualify as an associable context for a function_call_output.
func TestHasToolCallContext(t *testing.T) {
	require.False(t, HasToolCallContext(nil))
	require.True(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
	}))
	require.True(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
	}))
	// Without a call_id the item provides no context.
	require.False(t, HasToolCallContext(map[string]any{
		"input": []any{map[string]any{"type": "tool_call"}},
	}))
}
// TestFunctionCallOutputCallIDs: only non-empty call_ids are extracted, and
// duplicates are collapsed.
func TestFunctionCallOutputCallIDs(t *testing.T) {
	require.Empty(t, FunctionCallOutputCallIDs(nil))
	callIDs := FunctionCallOutputCallIDs(map[string]any{
		"input": []any{
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
			map[string]any{"type": "function_call_output", "call_id": ""},
			map[string]any{"type": "function_call_output", "call_id": "call_1"},
		},
	})
	require.ElementsMatch(t, []string{"call_1"}, callIDs)
}
// TestHasFunctionCallOutputMissingCallID: detection fires only for a
// function_call_output item without a usable call_id.
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
	require.False(t, HasFunctionCallOutputMissingCallID(nil))
	require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output"}},
	}))
	require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
		"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
	}))
}
// TestHasItemReferenceForCallIDs: the item_reference IDs must cover every
// call_id for the input to count as an associable context.
func TestHasItemReferenceForCallIDs(t *testing.T) {
	require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
	require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
	req := map[string]any{
		"input": []any{
			map[string]any{"type": "item_reference", "id": "call_1"},
			map[string]any{"type": "item_reference", "id": "call_2"},
		},
	}
	require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
	require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
	// One uncovered call_id is enough to fail the check.
	require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
}

View File

@@ -0,0 +1,307 @@
package service
import (
"encoding/json"
"fmt"
"log"
"sync"
)
// codexToolNameMapping maps Codex-native tool names (both snake_case and
// camelCase spellings) to their OpenCode equivalents.
var codexToolNameMapping = map[string]string{
	"apply_patch":  "edit",
	"applyPatch":   "edit",
	"update_plan":  "todowrite",
	"updatePlan":   "todowrite",
	"read_plan":    "todoread",
	"readPlan":     "todoread",
	"search_files": "grep",
	"searchFiles":  "grep",
	"list_files":   "glob",
	"listFiles":    "glob",
	"read_file":    "read",
	"readFile":     "read",
	"write_file":   "write",
	"writeFile":    "write",
	"execute_bash": "bash",
	"executeBash":  "bash",
	"exec_bash":    "bash",
	"execBash":     "bash",
}
// ToolCorrectionStats records tool-correction statistics (exported fields for
// JSON serialization). CorrectionsByTool is keyed by "from->to" pairs.
type ToolCorrectionStats struct {
	TotalCorrected    int            `json:"total_corrected"`
	CorrectionsByTool map[string]int `json:"corrections_by_tool"`
}
// CodexToolCorrector rewrites Codex tool calls to OpenCode conventions.
// The mutex guards stats; construct instances via NewCodexToolCorrector so the
// inner map is initialized. Must not be copied after first use (contains a mutex).
type CodexToolCorrector struct {
	stats ToolCorrectionStats
	mu    sync.RWMutex
}
// NewCodexToolCorrector returns a corrector with zeroed statistics and an
// initialized corrections map.
func NewCodexToolCorrector() *CodexToolCorrector {
	corrector := &CodexToolCorrector{}
	corrector.stats.CorrectionsByTool = make(map[string]int)
	return corrector
}
// CorrectToolCallsInSSEData corrects tool calls inside one SSE data payload.
// It returns the (possibly rewritten) data and whether any correction was made.
//
// Non-JSON payloads are passed through untouched. Corrections are applied to
// tool_calls / function_call found at the top level, under "delta", and under
// each choices[i].message / choices[i].delta.
//
// Refactor: the identical tool_calls + function_call handling was previously
// copy-pasted four times; it is now factored into correctToolContainers.
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
	if data == "" || data == "\n" {
		return data, false
	}
	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not valid JSON; return the original data unchanged.
		return data, false
	}
	corrected := c.correctToolContainers(payload)
	if delta, ok := payload["delta"].(map[string]any); ok {
		if c.correctToolContainers(delta) {
			corrected = true
		}
	}
	if choices, ok := payload["choices"].([]any); ok {
		for _, choice := range choices {
			choiceMap, ok := choice.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := choiceMap["message"].(map[string]any); ok {
				if c.correctToolContainers(message) {
					corrected = true
				}
			}
			if delta, ok := choiceMap["delta"].(map[string]any); ok {
				if c.correctToolContainers(delta) {
					corrected = true
				}
			}
		}
	}
	if !corrected {
		return data, false
	}
	// Serialize the corrected payload back to JSON.
	correctedBytes, err := json.Marshal(payload)
	if err != nil {
		log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
		return data, false
	}
	return string(correctedBytes), true
}

// correctToolContainers applies corrections to the "tool_calls" array and the
// "function_call" object of a single container map (payload, delta, or
// message). Reports whether anything changed.
func (c *CodexToolCorrector) correctToolContainers(container map[string]any) bool {
	corrected := false
	if toolCalls, ok := container["tool_calls"].([]any); ok {
		if c.correctToolCallsArray(toolCalls) {
			corrected = true
		}
	}
	if functionCall, ok := container["function_call"].(map[string]any); ok {
		if c.correctFunctionCall(functionCall) {
			corrected = true
		}
	}
	return corrected
}
// correctToolCallsArray applies name/parameter corrections to every entry of a
// tool_calls array; reports whether any entry changed.
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
	changed := false
	for _, raw := range toolCalls {
		entry, ok := raw.(map[string]any)
		if !ok {
			continue
		}
		fn, ok := entry["function"].(map[string]any)
		if !ok {
			continue
		}
		if c.correctFunctionCall(fn) {
			changed = true
		}
	}
	return changed
}
// correctFunctionCall corrects a single function call's tool name (via
// codexToolNameMapping) and then its arguments; reports whether anything
// changed. Calls with a missing or empty name are left untouched.
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
	name, ok := functionCall["name"].(string)
	if !ok || name == "" {
		return false
	}
	renamed := false
	if mapped, found := codexToolNameMapping[name]; found {
		functionCall["name"] = mapped
		c.recordCorrection(name, mapped)
		// Parameter rules are keyed by the corrected name.
		name = mapped
		renamed = true
	}
	paramsFixed := c.correctToolParameters(name, functionCall)
	return renamed || paramsFixed
}
// correctToolParameters rewrites tool arguments so they match the OpenCode
// schema for the given (already corrected) tool name.
//
// arguments may be a JSON-encoded string or an already-decoded map; any other
// shape, or an undecodable string, is left untouched. Returns whether the
// arguments were actually changed.
//
// Fix: previously, if re-marshaling string-form arguments failed, the function
// still returned true even though nothing had been written back; it now logs
// and returns false in that case so callers never claim a lost correction.
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
	arguments, ok := functionCall["arguments"]
	if !ok {
		return false
	}
	var argsMap map[string]any
	argsWereString := false
	switch v := arguments.(type) {
	case string:
		if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
			return false
		}
		argsWereString = true
	case map[string]any:
		argsMap = v
	default:
		return false
	}
	corrected := false
	switch toolName {
	case "bash":
		// OpenCode's bash tool does not accept a working-directory argument.
		for _, key := range []string{"workdir", "work_dir"} {
			if _, exists := argsMap[key]; exists {
				delete(argsMap, key)
				corrected = true
				log.Printf("[CodexToolCorrector] Removed '%s' parameter from bash tool", key)
			}
		}
	case "edit":
		// Map Codex's "path" to OpenCode's "file_path" unless the latter is
		// already present.
		if _, exists := argsMap["file_path"]; !exists {
			if path, exists := argsMap["path"]; exists {
				argsMap["file_path"] = path
				delete(argsMap, "path")
				corrected = true
				log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
			}
		}
	}
	if !corrected {
		return false
	}
	if argsWereString {
		// Arguments arrived as a JSON string; serialize the fixed map back.
		newArgsJSON, err := json.Marshal(argsMap)
		if err != nil {
			// Nothing was written back — do not report a correction.
			log.Printf("[CodexToolCorrector] Failed to re-marshal corrected arguments: %v", err)
			return false
		}
		functionCall["arguments"] = string(newArgsJSON)
	} else {
		// Arguments were a map; the (mutated) map is assigned back directly.
		functionCall["arguments"] = argsMap
	}
	return true
}
// recordCorrection records one tool-name correction under the "from->to" key
// and bumps the total, holding the write lock for the whole update.
func (c *CodexToolCorrector) recordCorrection(from, to string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.stats.TotalCorrected++
	key := fmt.Sprintf("%s->%s", from, to)
	c.stats.CorrectionsByTool[key]++
	log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
		from, to, c.stats.TotalCorrected)
}
// GetStats returns a snapshot copy of the correction statistics; the inner map
// is cloned so callers cannot race with concurrent corrections.
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
	c.mu.RLock()
	defer c.mu.RUnlock()
	byTool := make(map[string]int, len(c.stats.CorrectionsByTool))
	for pair, count := range c.stats.CorrectionsByTool {
		byTool[pair] = count
	}
	return ToolCorrectionStats{
		TotalCorrected:    c.stats.TotalCorrected,
		CorrectionsByTool: byTool,
	}
}
// ResetStats clears all correction statistics back to their initial state.
func (c *CodexToolCorrector) ResetStats() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.stats = ToolCorrectionStats{
		CorrectionsByTool: make(map[string]int),
	}
}
// CorrectToolName maps a Codex tool name to its OpenCode equivalent (for
// non-SSE call sites). The second result reports whether a mapping applied;
// unknown names are returned unchanged.
func CorrectToolName(name string) (string, bool) {
	mapped, found := codexToolNameMapping[name]
	if !found {
		return name, false
	}
	return mapped, true
}
// GetToolNameMapping returns a defensive copy of the Codex→OpenCode tool-name
// mapping so callers cannot mutate the package-level table.
func GetToolNameMapping() map[string]string {
	snapshot := make(map[string]string, len(codexToolNameMapping))
	for codexName, openCodeName := range codexToolNameMapping {
		snapshot[codexName] = openCodeName
	}
	return snapshot
}

View File

@@ -0,0 +1,503 @@
package service
import (
"encoding/json"
"testing"
)
// toolNameAt parses raw as JSON and walks path down to a function-call object,
// returning its "name". String steps index objects; int steps index arrays.
// Any structural mismatch fails the test immediately.
func toolNameAt(t *testing.T, raw string, path ...any) string {
	t.Helper()
	var node any
	if err := json.Unmarshal([]byte(raw), &node); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}
	for _, step := range path {
		switch key := step.(type) {
		case string:
			obj, ok := node.(map[string]any)
			if !ok {
				t.Fatalf("Expected JSON object at step %q", key)
			}
			node = obj[key]
		case int:
			arr, ok := node.([]any)
			if !ok || key >= len(arr) {
				t.Fatalf("Expected JSON array with index %d", key)
			}
			node = arr[key]
		default:
			t.Fatalf("Unsupported path step type %T", step)
		}
	}
	fn, ok := node.(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}
	name, _ := fn["name"].(string)
	return name
}

// TestCorrectToolCallsInSSEData exercises every location the corrector scans
// (top-level tool_calls/function_call, delta, choices[].message/delta) plus
// pass-through cases. The previously copy-pasted JSON navigation in each
// checkFunc is consolidated into the toolNameAt helper.
func TestCorrectToolCallsInSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name            string
		input           string
		expectCorrected bool
		checkFunc       func(t *testing.T, result string)
	}{
		{name: "empty string", input: "", expectCorrected: false},
		{name: "newline only", input: "\n", expectCorrected: false},
		{name: "invalid json", input: "not a json", expectCorrected: false},
		{
			name:            "correct apply_patch in tool_calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "tool_calls", 0, "function"); got != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", got)
				}
			},
		},
		{
			name:            "correct update_plan in function_call",
			input:           `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "function_call"); got != "todowrite" {
					t.Errorf("Expected tool name 'todowrite', got '%v'", got)
				}
			},
		},
		{
			name:            "correct search_files in delta.tool_calls",
			input:           `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "delta", "tool_calls", 0, "function"); got != "grep" {
					t.Errorf("Expected tool name 'grep', got '%v'", got)
				}
			},
		},
		{
			name:            "correct list_files in choices.message.tool_calls",
			input:           `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "choices", 0, "message", "tool_calls", 0, "function"); got != "glob" {
					t.Errorf("Expected tool name 'glob', got '%v'", got)
				}
			},
		},
		{
			name:            "no correction needed",
			input:           `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
			expectCorrected: false,
		},
		{
			name:            "correct multiple tool calls",
			input:           `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "tool_calls", 0, "function"); got != "edit" {
					t.Errorf("Expected first tool name 'edit', got '%v'", got)
				}
				if got := toolNameAt(t, result, "tool_calls", 1, "function"); got != "read" {
					t.Errorf("Expected second tool name 'read', got '%v'", got)
				}
			},
		},
		{
			name:            "camelCase format - applyPatch",
			input:           `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
			expectCorrected: true,
			checkFunc: func(t *testing.T, result string) {
				if got := toolNameAt(t, result, "tool_calls", 0, "function"); got != "edit" {
					t.Errorf("Expected tool name 'edit', got '%v'", got)
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
			if corrected != tt.expectCorrected {
				t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
			}
			if !corrected && result != tt.input {
				t.Errorf("Expected unchanged result when not corrected")
			}
			if tt.checkFunc != nil {
				tt.checkFunc(t, result)
			}
		})
	}
}
// TestCorrectToolName covers every entry of the name mapping table (snake_case
// and camelCase variants) plus names that must pass through unchanged.
func TestCorrectToolName(t *testing.T) {
	tests := []struct {
		input     string
		expected  string
		corrected bool
	}{
		{"apply_patch", "edit", true},
		{"applyPatch", "edit", true},
		{"update_plan", "todowrite", true},
		{"updatePlan", "todowrite", true},
		{"read_plan", "todoread", true},
		{"readPlan", "todoread", true},
		{"search_files", "grep", true},
		{"searchFiles", "grep", true},
		{"list_files", "glob", true},
		{"listFiles", "glob", true},
		{"read_file", "read", true},
		{"readFile", "read", true},
		{"write_file", "write", true},
		{"writeFile", "write", true},
		{"execute_bash", "bash", true},
		{"executeBash", "bash", true},
		{"exec_bash", "bash", true},
		{"execBash", "bash", true},
		{"unknown_tool", "unknown_tool", false},
		{"read", "read", false},
		{"edit", "edit", false},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			result, corrected := CorrectToolName(tt.input)
			if corrected != tt.corrected {
				t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
			}
			if result != tt.expected {
				t.Errorf("Expected '%s', got '%s'", tt.expected, result)
			}
		})
	}
}
// TestGetToolNameMapping spot-checks the returned mapping and verifies it is a
// defensive copy: mutating the returned map must not leak into later calls.
func TestGetToolNameMapping(t *testing.T) {
	mapping := GetToolNameMapping()
	expectedMappings := map[string]string{
		"apply_patch":  "edit",
		"update_plan":  "todowrite",
		"read_plan":    "todoread",
		"search_files": "grep",
		"list_files":   "glob",
	}
	for from, to := range expectedMappings {
		if mapping[from] != to {
			t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
		}
	}
	// Mutate the returned copy and confirm the package-level table is intact.
	mapping["test_tool"] = "test_value"
	newMapping := GetToolNameMapping()
	if _, exists := newMapping["test_tool"]; exists {
		t.Error("Modifications to returned mapping should not affect original")
	}
}
// TestCorrectorStats verifies that the corrector accumulates a total
// correction count plus per-mapping counters (keyed "<from>-><to>"), and
// that ResetStats clears both.
func TestCorrectorStats(t *testing.T) {
	corrector := NewCodexToolCorrector()
	// A fresh corrector must start with zeroed stats.
	stats := corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
	}
	// Trigger two apply_patch corrections and one update_plan correction.
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
	corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
	stats = corrector.GetStats()
	if stats.TotalCorrected != 3 {
		t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
	}
	// Per-mapping counters are keyed by the "<from>-><to>" rewrite pair.
	if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
		t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
	}
	if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
		t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
	}
	// Reset must zero the total and empty the per-mapping map.
	corrector.ResetStats()
	stats = corrector.GetStats()
	if stats.TotalCorrected != 0 {
		t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
	}
	if len(stats.CorrectionsByTool) != 0 {
		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
	}
}
// TestComplexSSEData runs the corrector over a realistic multi-line
// chat.completion.chunk payload and asserts that the nested tool name is
// rewritten from "apply_patch" to "edit" while the surrounding JSON
// structure remains intact and parseable.
func TestComplexSSEData(t *testing.T) {
	corrector := NewCodexToolCorrector()
	input := `{
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1234567890,
"model": "gpt-5.1-codex",
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
"function": {
"name": "apply_patch",
"arguments": "{\"file\":\"test.go\"}"
}
}
]
},
"finish_reason": null
}
]
}`
	result, corrected := corrector.CorrectToolCallsInSSEData(input)
	if !corrected {
		t.Error("Expected data to be corrected")
	}
	// Walk choices[0].delta.tool_calls[0].function manually so a failure
	// pinpoints exactly which level of the payload was damaged.
	var payload map[string]any
	if err := json.Unmarshal([]byte(result), &payload); err != nil {
		t.Fatalf("Failed to parse result: %v", err)
	}
	choices, ok := payload["choices"].([]any)
	if !ok || len(choices) == 0 {
		t.Fatal("No choices found in result")
	}
	choice, ok := choices[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid choice format")
	}
	delta, ok := choice["delta"].(map[string]any)
	if !ok {
		t.Fatal("Invalid delta format")
	}
	toolCalls, ok := delta["tool_calls"].([]any)
	if !ok || len(toolCalls) == 0 {
		t.Fatal("No tool_calls found in delta")
	}
	toolCall, ok := toolCalls[0].(map[string]any)
	if !ok {
		t.Fatal("Invalid tool_call format")
	}
	function, ok := toolCall["function"].(map[string]any)
	if !ok {
		t.Fatal("Invalid function format")
	}
	// The only change expected: the alias is replaced by the canonical name.
	if function["name"] != "edit" {
		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
	}
}
// TestCorrectToolParameters 测试工具参数修正
// (Tests tool-parameter correction: each case feeds a tool call through the
// corrector and checks, by presence/absence, that parameters were removed or
// renamed inside the nested "arguments" JSON string.)
func TestCorrectToolParameters(t *testing.T) {
	corrector := NewCodexToolCorrector()
	tests := []struct {
		name     string
		input    string
		expected map[string]bool // key: parameter name; value: true if it should exist after correction
	}{
		{
			// "workdir" is not a valid bash parameter and must be stripped.
			name: "remove workdir from bash tool",
			input: `{
"tool_calls": [{
"function": {
"name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
}
}]
}`,
			expected: map[string]bool{
				"command": true,
				"workdir": false,
			},
		},
		{
			// edit (via the apply_patch alias) expects "file_path", not "path".
			name: "rename path to file_path in edit tool",
			input: `{
"tool_calls": [{
"function": {
"name": "apply_patch",
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
}
}]
}`,
			expected: map[string]bool{
				"file_path":  true,
				"path":       false,
				"old_string": true,
				"new_string": true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
			if !changed {
				t.Error("expected data to be corrected")
			}
			// 解析修正后的数据 (parse the corrected payload)
			var result map[string]any
			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
				t.Fatalf("failed to parse corrected data: %v", err)
			}
			// 检查工具调用 (drill down to the first tool call's arguments)
			toolCalls, ok := result["tool_calls"].([]any)
			if !ok || len(toolCalls) == 0 {
				t.Fatal("no tool_calls found in corrected data")
			}
			toolCall, ok := toolCalls[0].(map[string]any)
			if !ok {
				t.Fatal("invalid tool_call structure")
			}
			function, ok := toolCall["function"].(map[string]any)
			if !ok {
				t.Fatal("no function found in tool_call")
			}
			// "arguments" is itself a JSON document encoded as a string, so it
			// needs a second Unmarshal pass.
			argumentsStr, ok := function["arguments"].(string)
			if !ok {
				t.Fatal("arguments is not a string")
			}
			var args map[string]any
			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
				t.Fatalf("failed to parse arguments: %v", err)
			}
			// 验证期望的参数 (assert presence/absence per the expectation table)
			for param, shouldExist := range tt.expected {
				_, exists := args[param]
				if shouldExist && !exists {
					t.Errorf("expected parameter %q to exist, but it doesn't", param)
				}
				if !shouldExist && exists {
					t.Errorf("expected parameter %q to not exist, but it does", param)
				}
			}
		})
	}
}

View File

@@ -0,0 +1,194 @@
package service
import (
"context"
"errors"
"time"
)
// GetAccountAvailabilityStats returns current account availability stats.
//
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
//
// It returns three views over the same account snapshot — per-platform
// counters, per-group counters, and a per-account detail map — plus the
// collection timestamp. Requires monitoring to be enabled.
func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFilter string, groupIDFilter *int64) (
	map[string]*PlatformAvailability,
	map[int64]*GroupAvailability,
	map[int64]*AccountAvailability,
	*time.Time,
	error,
) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, nil, nil, err
	}
	accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
	if err != nil {
		return nil, nil, nil, nil, err
	}
	// Group filtering happens in memory: keep only accounts belonging to the
	// requested group (an account may be in several groups).
	if groupIDFilter != nil && *groupIDFilter > 0 {
		filtered := make([]Account, 0, len(accounts))
		for _, acc := range accounts {
			for _, grp := range acc.Groups {
				if grp != nil && grp.ID == *groupIDFilter {
					filtered = append(filtered, acc)
					break
				}
			}
		}
		accounts = filtered
	}
	now := time.Now()
	collectedAt := now
	platform := make(map[string]*PlatformAvailability)
	group := make(map[int64]*GroupAvailability)
	account := make(map[int64]*AccountAvailability)
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		// Derive transient states from their expiry timestamps relative to now.
		isTempUnsched := false
		if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
			isTempUnsched = true
		}
		isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
		isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
		hasError := acc.Status == StatusError
		// Normalize exclusive status flags so the UI doesn't show conflicting badges.
		if hasError {
			isRateLimited = false
			isOverloaded = false
		}
		// "Available" means active, schedulable, and free of all transient blocks.
		isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
		// Per-platform rollup.
		if acc.Platform != "" {
			if _, ok := platform[acc.Platform]; !ok {
				platform[acc.Platform] = &PlatformAvailability{
					Platform: acc.Platform,
				}
			}
			p := platform[acc.Platform]
			p.TotalAccounts++
			if isAvailable {
				p.AvailableCount++
			}
			if isRateLimited {
				p.RateLimitCount++
			}
			if hasError {
				p.ErrorCount++
			}
		}
		// Per-group rollup: an account contributes to every group it belongs to.
		for _, grp := range acc.Groups {
			if grp == nil || grp.ID <= 0 {
				continue
			}
			if _, ok := group[grp.ID]; !ok {
				group[grp.ID] = &GroupAvailability{
					GroupID:   grp.ID,
					GroupName: grp.Name,
					Platform:  grp.Platform,
				}
			}
			g := group[grp.ID]
			g.TotalAccounts++
			if isAvailable {
				g.AvailableCount++
			}
			if isRateLimited {
				g.RateLimitCount++
			}
			if hasError {
				g.ErrorCount++
			}
		}
		// The per-account row displays only the first group (UI convention).
		displayGroupID := int64(0)
		displayGroupName := ""
		if len(acc.Groups) > 0 && acc.Groups[0] != nil {
			displayGroupID = acc.Groups[0].ID
			displayGroupName = acc.Groups[0].Name
		}
		item := &AccountAvailability{
			AccountID:     acc.ID,
			AccountName:   acc.Name,
			Platform:      acc.Platform,
			GroupID:       displayGroupID,
			GroupName:     displayGroupName,
			Status:        acc.Status,
			IsAvailable:   isAvailable,
			IsRateLimited: isRateLimited,
			IsOverloaded:  isOverloaded,
			HasError:      hasError,
			ErrorMessage:  acc.ErrorMessage,
		}
		// Countdown fields are only set while positive, so the UI can hide
		// timers that have already elapsed between collection and render.
		if isRateLimited && acc.RateLimitResetAt != nil {
			item.RateLimitResetAt = acc.RateLimitResetAt
			remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
			if remainingSec > 0 {
				item.RateLimitRemainingSec = &remainingSec
			}
		}
		if isOverloaded && acc.OverloadUntil != nil {
			item.OverloadUntil = acc.OverloadUntil
			remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
			if remainingSec > 0 {
				item.OverloadRemainingSec = &remainingSec
			}
		}
		if isTempUnsched && acc.TempUnschedulableUntil != nil {
			item.TempUnschedulableUntil = acc.TempUnschedulableUntil
		}
		account[acc.ID] = item
	}
	return platform, group, account, &collectedAt, nil
}
// OpsAccountAvailability bundles a single group's availability summary
// (nil when no group filter was applied) with the per-account detail map
// and the timestamp at which the snapshot was collected.
type OpsAccountAvailability struct {
	Group       *GroupAvailability
	Accounts    map[int64]*AccountAvailability
	CollectedAt *time.Time
}
// GetAccountAvailability returns account availability scoped by the optional
// platform/group filters. When an override function is installed on the
// service (s.getAccountAvailability — presumably a test seam) it is used
// verbatim; otherwise the result is assembled from GetAccountAvailabilityStats.
// Group is populated only for a positive group filter; Accounts is never nil.
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
	if s == nil {
		return nil, errors.New("ops service is nil")
	}
	if hook := s.getAccountAvailability; hook != nil {
		return hook(ctx, platformFilter, groupIDFilter)
	}
	_, groups, accounts, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
	if err != nil {
		return nil, err
	}
	result := &OpsAccountAvailability{
		Accounts:    accounts,
		CollectedAt: collectedAt,
	}
	if result.Accounts == nil {
		result.Accounts = map[int64]*AccountAvailability{}
	}
	if groupIDFilter != nil && *groupIDFilter > 0 {
		result.Group = groups[*groupIDFilter]
	}
	return result, nil
}

View File

@@ -0,0 +1,46 @@
package service
import (
"context"
"database/sql"
"hash/fnv"
"time"
)
// hashAdvisoryLockID maps an arbitrary string key onto a 64-bit advisory
// lock ID using FNV-1a, which is stable across processes and restarts so
// every replica derives the same lock ID for the same key.
func hashAdvisoryLockID(key string) int64 {
	hasher := fnv.New64a()
	_, _ = hasher.Write([]byte(key)) // fnv.Write never returns an error
	return int64(hasher.Sum64())
}
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
if db == nil {
return nil, false
}
if ctx == nil {
ctx = context.Background()
}
conn, err := db.Conn(ctx)
if err != nil {
return nil, false
}
acquired := false
if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
_ = conn.Close()
return nil, false
}
if !acquired {
_ = conn.Close()
return nil, false
}
release := func() {
unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
_ = conn.Close()
}
return release, true
}

View File

@@ -0,0 +1,448 @@
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// Job names, cadences, windows, and leader-lock settings for the ops
// pre-aggregation background jobs.
const (
	opsAggHourlyJobName = "ops_preaggregation_hourly"
	opsAggDailyJobName  = "ops_preaggregation_daily"
	// Wake-up cadence of the hourly/daily backfill loops.
	opsAggHourlyInterval = 10 * time.Minute
	opsAggDailyInterval  = 1 * time.Hour
	// Keep in sync with ops retention target (vNext default 30d).
	opsAggBackfillWindow = 30 * 24 * time.Hour
	// Recompute overlap to absorb late-arriving rows near boundaries.
	opsAggHourlyOverlap = 2 * time.Hour
	opsAggDailyOverlap  = 48 * time.Hour
	// Maximum time range aggregated per repository call.
	opsAggHourlyChunk = 24 * time.Hour
	opsAggDailyChunk  = 7 * 24 * time.Hour
	// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
	// that may still receive late inserts.
	opsAggSafeDelay       = 5 * time.Minute
	opsAggMaxQueryTimeout = 3 * time.Second
	// Per-run deadlines for a whole aggregation pass.
	opsAggHourlyTimeout = 5 * time.Minute
	opsAggDailyTimeout  = 2 * time.Minute
	// Redis leader-lock keys and TTLs for multi-replica coordination.
	opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
	opsAggDailyLeaderLockKey  = "ops:aggregation:daily:leader"
	opsAggHourlyLeaderLockTTL = 15 * time.Minute
	opsAggDailyLeaderLockTTL  = 10 * time.Minute
)
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
// for stable long-window dashboard queries.
//
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
type OpsAggregationService struct {
	opsRepo     OpsRepository
	settingRepo SettingRepository
	cfg         *config.Config
	db          *sql.DB
	redisClient *redis.Client
	// instanceID identifies this replica as the Redis leader-lock owner.
	instanceID string
	// stopCh is closed by Stop to terminate both loops.
	stopCh    chan struct{}
	startOnce sync.Once
	stopOnce  sync.Once
	// hourlyMu/dailyMu serialize aggregation passes within this process.
	hourlyMu sync.Mutex
	dailyMu  sync.Mutex
	// skipLogMu/skipLogAt rate-limit "lock held elsewhere" log lines.
	skipLogMu sync.Mutex
	skipLogAt time.Time
}
// NewOpsAggregationService wires up an aggregation service. A fresh random
// instanceID distinguishes replicas competing for the Redis leader lock.
// Call Start to launch the background loops.
func NewOpsAggregationService(
	opsRepo OpsRepository,
	settingRepo SettingRepository,
	db *sql.DB,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsAggregationService {
	svc := &OpsAggregationService{
		instanceID:  uuid.NewString(),
		opsRepo:     opsRepo,
		settingRepo: settingRepo,
		cfg:         cfg,
		db:          db,
		redisClient: redisClient,
	}
	return svc
}
// Start launches the hourly and daily aggregation loops. Idempotent: only
// the first call has an effect.
//
// NOTE(review): if Stop runs before Start, Stop's sync.Once is consumed while
// stopCh is still nil, so loops started afterwards cannot be stopped —
// confirm callers always Start before Stop.
func (s *OpsAggregationService) Start() {
	if s == nil {
		return
	}
	s.startOnce.Do(func() {
		if s.stopCh == nil {
			s.stopCh = make(chan struct{})
		}
		go s.hourlyLoop()
		go s.dailyLoop()
	})
}
// Stop signals both loops to exit by closing stopCh. Idempotent; it does not
// wait for an in-flight aggregation pass to finish.
func (s *OpsAggregationService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		if s.stopCh != nil {
			close(s.stopCh)
		}
	})
}
// hourlyLoop performs one hourly aggregation pass immediately (so fresh
// deployments backfill without waiting a full interval), then repeats on a
// fixed ticker until Stop closes stopCh.
func (s *OpsAggregationService) hourlyLoop() {
	s.aggregateHourly()
	ticker := time.NewTicker(opsAggHourlyInterval)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.aggregateHourly()
		}
	}
}
// dailyLoop performs one daily aggregation pass immediately, then repeats on
// a fixed ticker until Stop closes stopCh.
func (s *OpsAggregationService) dailyLoop() {
	s.aggregateDaily()
	ticker := time.NewTicker(opsAggDailyInterval)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.aggregateDaily()
		}
	}
}
// aggregateHourly backfills the hourly metrics table: it takes the leader
// lock, determines the [start, end) window of stable full hours still
// needing aggregation, upserts them in bounded chunks, and records a job
// heartbeat (success or error) for the ops dashboard.
func (s *OpsAggregationService) aggregateHourly() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Both the global ops switch and the aggregation-specific switch must be on.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
	defer cancel()
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	// Serialize passes within this process as well (lock may be DB/Redis based).
	s.hourlyMu.Lock()
	defer s.hourlyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Aggregate stable full hours only.
	end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest bucket with overlap.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to the full backfill window.
			log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggHourlyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToHour(start)
	if !start.Before(end) {
		// Nothing new to aggregate.
		return
	}
	// Upsert in chunks so a single repository call stays bounded; stop at the
	// first failure and report it via the heartbeat below.
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
		if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		// Error heartbeat: record when and why the run failed.
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggHourlyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	// Success heartbeat with the aggregated window for observability.
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggHourlyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// aggregateDaily backfills the daily metrics table. Mirrors aggregateHourly
// but works on whole UTC days, with its own overlap, chunk size, leader lock,
// and heartbeat job name.
func (s *OpsAggregationService) aggregateDaily() {
	if s == nil || s.opsRepo == nil {
		return
	}
	// Both the global ops switch and the aggregation-specific switch must be on.
	if s.cfg != nil {
		if !s.cfg.Ops.Enabled {
			return
		}
		if !s.cfg.Ops.Aggregation.Enabled {
			return
		}
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
	defer cancel()
	if !s.isMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	s.dailyMu.Lock()
	defer s.dailyMu.Unlock()
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Aggregate completed UTC days only (today is excluded by flooring).
	end := utcFloorToDay(time.Now().UTC())
	start := end.Add(-opsAggBackfillWindow)
	// Resume from the latest aggregated day, re-covering an overlap window to
	// absorb late-arriving rows.
	{
		ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
		latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
		cancelMax()
		if err != nil {
			// Non-fatal: fall back to the full backfill window.
			log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
		} else if ok {
			candidate := latest.Add(-opsAggDailyOverlap)
			if candidate.After(start) {
				start = candidate
			}
		}
	}
	start = utcFloorToDay(start)
	if !start.Before(end) {
		// Nothing new to aggregate.
		return
	}
	var aggErr error
	for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
		chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
		if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
			aggErr = err
			log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
			break
		}
	}
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	if aggErr != nil {
		// Error heartbeat: record when and why the run failed.
		msg := truncateString(aggErr.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer hbCancel()
		_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsAggDailyJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		return
	}
	// Success heartbeat with the aggregated window for observability.
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer hbCancel()
	result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
	_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAggDailyJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
		LastResult:     &result,
	})
}
// isMonitoringEnabled reports whether ops monitoring is enabled, combining the
// static config switch with the runtime setting stored via settingRepo.
// The policy is fail-open: a missing setting or a lookup error leaves
// monitoring enabled so aggregation does not silently stop.
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
	if s == nil {
		return false
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return false
	}
	if s.settingRepo == nil {
		return true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
	if err != nil {
		if errors.Is(err, ErrSettingNotFound) {
			// No explicit setting stored: default to enabled.
			return true
		}
		// NOTE(review): this branch also returns true, making the
		// ErrSettingNotFound check above redundant. Kept as-is to document the
		// fail-open intent; consider logging unexpected errors here.
		return true
	}
	// Only an explicit "off" value disables monitoring.
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "false", "0", "off", "disabled":
		return false
	default:
		return true
	}
}
// opsAggReleaseScript releases the leader lock only if this instance still
// owns it (atomic compare-and-delete), so a lock that expired and was
// re-acquired by another replica is never deleted by mistake.
var opsAggReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// tryAcquireLeaderLock elects a single replica to run an aggregation pass.
// It returns a release func (may be nil) and whether the lock was obtained.
// Redis (SET NX with TTL) is preferred; on Redis errors it degrades to a
// PostgreSQL advisory lock rather than letting every replica hit the DB.
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
	if s == nil {
		return nil, false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
	// the DB when Redis is flaky by falling back to a DB advisory lock.
	if s.redisClient != nil {
		ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
		if err == nil {
			if !ok {
				// Another replica holds the lock; log at most once a minute.
				s.maybeLogSkip(logPrefix)
				return nil, false
			}
			release := func() {
				// Compare-and-delete with a fresh short-lived context so the
				// release works even after the caller's ctx expired.
				ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
				defer cancel()
				_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
			}
			return release, true
		}
		// Redis error: fall through to DB advisory lock.
	}
	release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
	if !ok {
		s.maybeLogSkip(logPrefix)
		return nil, false
	}
	return release, true
}
// maybeLogSkip logs "leader lock held by another instance" at most once per
// minute so the non-leader replicas do not flood the logs on every tick.
func (s *OpsAggregationService) maybeLogSkip(prefix string) {
	if prefix == "" {
		prefix = "[OpsAggregation]"
	}
	s.skipLogMu.Lock()
	defer s.skipLogMu.Unlock()
	now := time.Now()
	if s.skipLogAt.IsZero() || now.Sub(s.skipLogAt) >= time.Minute {
		s.skipLogAt = now
		log.Printf("%s leader lock held by another instance; skipping", prefix)
	}
}
func utcFloorToHour(t time.Time) time.Time {
return t.UTC().Truncate(time.Hour)
}
func utcFloorToDay(t time.Time) time.Time {
u := t.UTC()
y, m, d := u.Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
func minTime(a, b time.Time) time.Time {
if a.Before(b) {
return a
}
return b
}

View File

@@ -0,0 +1,944 @@
package service
import (
"context"
"fmt"
"log"
"math"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// Heartbeat job name, per-run deadline, leader-lock key/TTL, and the
// skip-log throttle interval for the alert evaluator loop.
const (
	opsAlertEvaluatorJobName         = "ops_alert_evaluator"
	opsAlertEvaluatorTimeout         = 45 * time.Second
	opsAlertEvaluatorLeaderLockKey   = "ops:alert:evaluator:leader"
	opsAlertEvaluatorLeaderLockTTL   = 90 * time.Second
	opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
)
// opsAlertEvaluatorReleaseScript releases the evaluator leader lock only if
// this instance still owns it (atomic compare-and-delete), preventing one
// replica from deleting a lock re-acquired by another after TTL expiry.
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// OpsAlertEvaluatorService periodically evaluates alert rules against live
// metrics, creates/resolves alert events, and optionally sends notification
// emails. Designed to be multi-replica safe via a Redis leader lock.
type OpsAlertEvaluatorService struct {
	opsService   *OpsService
	opsRepo      OpsRepository
	emailService *EmailService
	redisClient  *redis.Client
	cfg          *config.Config
	// instanceID identifies this replica as the leader-lock owner.
	instanceID string
	// stopCh is closed by Stop to terminate the evaluation loop.
	stopCh    chan struct{}
	startOnce sync.Once
	stopOnce  sync.Once
	wg        sync.WaitGroup
	// mu guards ruleStates (per-rule consecutive-breach tracking).
	mu         sync.Mutex
	ruleStates map[int64]*opsAlertRuleState
	// emailLimiter throttles outgoing alert emails over a sliding window.
	emailLimiter *slidingWindowLimiter
	// skipLogMu/skipLogAt rate-limit "lock held elsewhere" log lines.
	skipLogMu       sync.Mutex
	skipLogAt       time.Time
	warnNoRedisOnce sync.Once
}
// opsAlertRuleState tracks per-rule evaluation history in memory: when the
// rule was last evaluated and how many consecutive evaluations breached its
// threshold (used to implement "sustained for N minutes" alerting).
type opsAlertRuleState struct {
	LastEvaluatedAt     time.Time
	ConsecutiveBreaches int
}
// NewOpsAlertEvaluatorService builds an evaluator with a fresh replica
// instanceID, an empty rule-state map, and an hourly sliding-window email
// limiter.
// NOTE(review): the limiter is constructed with a 0 limit here — presumably
// the effective cap is applied where emails are sent; confirm before relying
// on it.
func NewOpsAlertEvaluatorService(
	opsService *OpsService,
	opsRepo OpsRepository,
	emailService *EmailService,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsAlertEvaluatorService {
	return &OpsAlertEvaluatorService{
		opsService:   opsService,
		opsRepo:      opsRepo,
		emailService: emailService,
		redisClient:  redisClient,
		cfg:          cfg,
		instanceID:   uuid.NewString(),
		ruleStates:   map[int64]*opsAlertRuleState{},
		emailLimiter: newSlidingWindowLimiter(0, time.Hour),
	}
}
// Start launches the evaluation loop once; subsequent calls are no-ops.
func (s *OpsAlertEvaluatorService) Start() {
	if s == nil {
		return
	}
	s.startOnce.Do(func() {
		if s.stopCh == nil {
			s.stopCh = make(chan struct{})
		}
		go s.run()
	})
}
// Stop signals the evaluation loop to exit and waits for it to finish.
//
// NOTE(review): run() calls wg.Add(1) inside the goroutine, so a Stop racing
// a just-issued Start may call wg.Wait() before the loop registers itself —
// consider moving wg.Add(1) into Start, next to `go s.run()`.
func (s *OpsAlertEvaluatorService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		if s.stopCh != nil {
			close(s.stopCh)
		}
	})
	s.wg.Wait()
}
// run is the evaluator main loop: it fires immediately, then re-arms a timer
// with the (runtime-configurable) interval after each pass, until stopCh closes.
func (s *OpsAlertEvaluatorService) run() {
	// NOTE(review): Add(1) inside the goroutine can race Stop's wg.Wait();
	// ideally it would happen in Start before `go s.run()`.
	s.wg.Add(1)
	defer s.wg.Done()
	// Start immediately to produce early feedback in ops dashboard.
	timer := time.NewTimer(0)
	defer timer.Stop()
	for {
		select {
		case <-timer.C:
			// Re-read the interval each pass so setting changes apply without restart.
			interval := s.getInterval()
			s.evaluateOnce(interval)
			timer.Reset(interval)
		case <-s.stopCh:
			return
		}
	}
}
// getInterval returns the evaluation loop interval. It re-reads the runtime
// setting on every call so operators can tune the cadence without a restart,
// and falls back to 60s whenever the service/setting is unavailable or the
// configured value is out of the accepted range [1s, 24h].
//
// Fix: the original performed the same lower-bound check twice
// (`<= 0` followed by `< 1` on an int); the dead duplicate is removed and
// both bounds are validated in one place.
func (s *OpsAlertEvaluatorService) getInterval() time.Duration {
	const defaultInterval = 60 * time.Second
	if s == nil || s.opsService == nil {
		return defaultInterval
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	cfg, err := s.opsService.GetOpsAlertRuntimeSettings(ctx)
	if err != nil || cfg == nil {
		return defaultInterval
	}
	secs := cfg.EvaluationIntervalSeconds
	// Reject non-positive and absurdly large (> 24h) values.
	if secs < 1 || secs > int((24*time.Hour).Seconds()) {
		return defaultInterval
	}
	return time.Duration(secs) * time.Second
}
// evaluateOnce performs a single evaluation pass under the leader lock:
// for every enabled rule it computes the current metric value, tracks
// consecutive breaches, creates a firing event (honoring silences and
// cooldowns), or resolves the active event when the rule recovered.
// Counters are summarized into a job heartbeat at the end.
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
	if s == nil || s.opsRepo == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
	defer cancel()
	if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
		return
	}
	// Load runtime settings, falling back to defaults if unavailable.
	runtimeCfg := defaultOpsAlertRuntimeSettings()
	if s.opsService != nil {
		if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
			runtimeCfg = loaded
		}
	}
	// Only one replica evaluates at a time.
	release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	startedAt := time.Now().UTC()
	runAt := startedAt
	rules, err := s.opsRepo.ListAlertRules(ctx)
	if err != nil {
		s.recordHeartbeatError(runAt, time.Since(startedAt), err)
		log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
		return
	}
	// Run counters reported via the heartbeat result string below.
	rulesTotal := len(rules)
	rulesEnabled := 0
	rulesEvaluated := 0
	eventsCreated := 0
	eventsResolved := 0
	emailsSent := 0
	now := time.Now().UTC()
	// Evaluate against minute-aligned windows so partial buckets are excluded.
	safeEnd := now.Truncate(time.Minute)
	if safeEnd.IsZero() {
		safeEnd = now
	}
	// Latest system snapshot feeds the cpu/memory/queue-depth metric types;
	// errors are tolerated (those rules simply won't evaluate).
	systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)
	// Cleanup stale state for removed rules.
	s.pruneRuleStates(rules)
	for _, rule := range rules {
		if rule == nil || !rule.Enabled || rule.ID <= 0 {
			continue
		}
		rulesEnabled++
		scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
		windowMinutes := rule.WindowMinutes
		if windowMinutes <= 0 {
			windowMinutes = 1
		}
		windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
		windowEnd := safeEnd
		metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
		if !ok {
			// Metric unavailable: clear the breach streak so a data gap does
			// not immediately fire once data returns.
			s.resetRuleState(rule.ID, now)
			continue
		}
		rulesEvaluated++
		breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
		required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
		consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)
		activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
		if err != nil {
			log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
			continue
		}
		if breachedNow && consecutive >= required {
			// Already firing: nothing to create.
			if activeEvent != nil {
				continue
			}
			// Scoped silencing: if a matching silence exists, skip creating a firing event.
			if s.opsService != nil {
				platform := strings.TrimSpace(scopePlatform)
				region := scopeRegion
				if platform != "" {
					if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
						continue
					}
				}
			}
			// Cooldown: suppress re-firing too soon after the previous event.
			latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			if latestEvent != nil && rule.CooldownMinutes > 0 {
				cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
				if now.Sub(latestEvent.FiredAt) < cooldown {
					continue
				}
			}
			firedEvent := &OpsAlertEvent{
				RuleID:         rule.ID,
				Severity:       strings.TrimSpace(rule.Severity),
				Status:         OpsAlertStatusFiring,
				Title:          fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
				Description:    buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
				MetricValue:    float64Ptr(metricValue),
				ThresholdValue: float64Ptr(rule.Threshold),
				Dimensions:     buildOpsAlertDimensions(scopePlatform, scopeGroupID),
				FiredAt:        now,
				CreatedAt:      now,
			}
			created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
			if err != nil {
				log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
				continue
			}
			eventsCreated++
			// Email notification only for successfully persisted events.
			if created != nil && created.ID > 0 {
				if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
					emailsSent++
				}
			}
			continue
		}
		// Not breached: resolve active event if present.
		if activeEvent != nil {
			resolvedAt := now
			if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
				log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
			} else {
				eventsResolved++
			}
		}
	}
	result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
// pruneRuleStates drops in-memory breach counters for rules that no longer
// exist, keeping the state map from growing without bound as rules are deleted.
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
	s.mu.Lock()
	defer s.mu.Unlock()
	alive := make(map[int64]struct{}, len(rules))
	for _, rule := range rules {
		if rule == nil || rule.ID <= 0 {
			continue
		}
		alive[rule.ID] = struct{}{}
	}
	for id := range s.ruleStates {
		if _, keep := alive[id]; !keep {
			delete(s.ruleStates, id)
		}
	}
}
// resetRuleState records an evaluation at `now` for the rule and clears its
// consecutive-breach streak (used when the rule's metric was unavailable).
func (s *OpsAlertEvaluatorService) resetRuleState(ruleID int64, now time.Time) {
	if ruleID <= 0 {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	st := s.ruleStates[ruleID]
	if st == nil {
		st = &opsAlertRuleState{}
		s.ruleStates[ruleID] = st
	}
	st.LastEvaluatedAt = now
	st.ConsecutiveBreaches = 0
}
// updateRuleBreaches records the outcome of one evaluation and returns the
// current consecutive-breach count for the rule. A gap of more than two
// intervals since the last evaluation (leader change, downtime) restarts the
// streak so stale breaches are not treated as consecutive.
func (s *OpsAlertEvaluatorService) updateRuleBreaches(ruleID int64, now time.Time, interval time.Duration, breached bool) int {
	if ruleID <= 0 {
		return 0
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	st := s.ruleStates[ruleID]
	if st == nil {
		st = &opsAlertRuleState{}
		s.ruleStates[ruleID] = st
	}
	if !st.LastEvaluatedAt.IsZero() && interval > 0 && now.Sub(st.LastEvaluatedAt) > interval*2 {
		st.ConsecutiveBreaches = 0
	}
	st.LastEvaluatedAt = now
	if !breached {
		st.ConsecutiveBreaches = 0
		return 0
	}
	st.ConsecutiveBreaches++
	return st.ConsecutiveBreaches
}
func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int {
if sustainedMinutes <= 0 {
return 1
}
if interval <= 0 {
return sustainedMinutes
}
required := int(math.Ceil(float64(sustainedMinutes*60) / interval.Seconds()))
if required < 1 {
return 1
}
return required
}
// parseOpsAlertRuleScope extracts the optional platform / group_id / region
// scope from a rule's filter map. Missing or malformed entries are ignored.
// group_id accepts JSON numbers (float64), native ints, and numeric strings,
// and only counts when positive; region is only set when non-blank after
// trimming.
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
	if filters == nil {
		return "", nil, nil
	}
	if raw, ok := filters["platform"].(string); ok {
		platform = strings.TrimSpace(raw)
	}
	switch v := filters["group_id"].(type) {
	case float64:
		// Positivity is checked on the raw value before truncation.
		if v > 0 {
			id := int64(v)
			groupID = &id
		}
	case int64:
		if v > 0 {
			id := v
			groupID = &id
		}
	case int:
		if v > 0 {
			id := int64(v)
			groupID = &id
		}
	case string:
		if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil && n > 0 {
			groupID = &n
		}
	}
	if raw, ok := filters["region"].(string); ok {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			region = &trimmed
		}
	}
	return platform, groupID, region
}
// computeRuleMetric resolves the current value for a rule's metric type.
// The boolean result is false when the metric cannot be computed right now
// (missing snapshot, missing scope, repo error, unknown type) — callers treat
// that as "skip this rule", not as a zero reading.
//
// Sources: cpu/memory/queue metrics come from the latest system snapshot;
// account/group availability metrics come from the live availability view;
// rate metrics come from a raw-mode dashboard overview over [start, end).
func (s *OpsAlertEvaluatorService) computeRuleMetric(
	ctx context.Context,
	rule *OpsAlertRule,
	systemMetrics *OpsSystemMetricsSnapshot,
	start time.Time,
	end time.Time,
	platform string,
	groupID *int64,
) (float64, bool) {
	if rule == nil {
		return 0, false
	}
	switch strings.TrimSpace(rule.MetricType) {
	case "cpu_usage_percent":
		if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
			return *systemMetrics.CPUUsagePercent, true
		}
		return 0, false
	case "memory_usage_percent":
		if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
			return *systemMetrics.MemoryUsagePercent, true
		}
		return 0, false
	case "concurrency_queue_depth":
		if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
			return float64(*systemMetrics.ConcurrencyQueueDepth), true
		}
		return 0, false
	case "group_available_accounts":
		// Requires a positive group scope on the rule.
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// A missing group resolves to 0 available accounts (valid reading).
		if availability.Group == nil {
			return 0, true
		}
		return float64(availability.Group.AvailableCount), true
	case "group_available_ratio":
		if groupID == nil || *groupID <= 0 {
			return 0, false
		}
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return computeGroupAvailableRatio(availability.Group), true
	case "account_rate_limited_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})), true
	case "account_error_count":
		if s == nil || s.opsService == nil {
			return 0, false
		}
		availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
		if err != nil || availability == nil {
			return 0, false
		}
		// Temporarily unschedulable accounts are excluded from the error count.
		return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})), true
	}
	// Remaining metric types are rates derived from the dashboard overview.
	overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
		StartTime: start,
		EndTime:   end,
		Platform:  platform,
		GroupID:   groupID,
		QueryMode: OpsQueryModeRaw,
	})
	if err != nil {
		return 0, false
	}
	if overview == nil {
		return 0, false
	}
	switch strings.TrimSpace(rule.MetricType) {
	case "success_rate":
		// Rates are undefined without SLA-relevant traffic in the window.
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.SLA * 100, true
	case "error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.ErrorRate * 100, true
	case "upstream_error_rate":
		if overview.RequestCountSLA <= 0 {
			return 0, false
		}
		return overview.UpstreamErrorRate * 100, true
	default:
		return 0, false
	}
}
// compareMetric evaluates `value <operator> threshold` where operator is one
// of > >= < <= == != (surrounding whitespace is ignored). Unknown operators
// yield false. Note == / != are exact float comparisons by design.
func compareMetric(value float64, operator string, threshold float64) bool {
	op := strings.TrimSpace(operator)
	if op == ">" {
		return value > threshold
	}
	if op == ">=" {
		return value >= threshold
	}
	if op == "<" {
		return value < threshold
	}
	if op == "<=" {
		return value <= threshold
	}
	if op == "==" {
		return value == threshold
	}
	if op == "!=" {
		return value != threshold
	}
	return false
}
// buildOpsAlertDimensions assembles the optional dimension map attached to an
// alert event: "platform" (trimmed, when non-blank) and "group_id" (when
// positive). Returns nil when neither dimension applies.
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
	out := make(map[string]any, 2)
	if p := strings.TrimSpace(platform); p != "" {
		out["platform"] = p
	}
	if groupID != nil && *groupID > 0 {
		out["group_id"] = *groupID
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
// buildOpsAlertDescription renders a one-line summary for a fired rule, e.g.
// `error_rate > 5.00 (current 7.31) over last 5m (platform=openai group_id=3)`.
// Returns "" for a nil rule; windowMinutes is floored at 1.
func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes int, platform string, groupID *int64) string {
	if rule == nil {
		return ""
	}
	if windowMinutes <= 0 {
		windowMinutes = 1
	}
	scope := "overall"
	if p := strings.TrimSpace(platform); p != "" {
		scope = fmt.Sprintf("platform=%s", p)
	}
	if groupID != nil && *groupID > 0 {
		scope = fmt.Sprintf("%s group_id=%d", scope, *groupID)
	}
	return fmt.Sprintf("%s %s %.2f (current %.2f) over last %dm (%s)",
		strings.TrimSpace(rule.MetricType),
		strings.TrimSpace(rule.Operator),
		rule.Threshold,
		value,
		windowMinutes,
		strings.TrimSpace(scope),
	)
}
// maybeSendAlertEmail best-effort delivers a notification email for the given
// alert event. It returns true when at least one recipient was sent (and the
// event's email_sent flag was persisted).
//
// The email is skipped (returning false) when:
//   - dependencies are missing, the event was already emailed, or the rule has
//     email notifications disabled;
//   - the global alert email config is disabled or has no recipients;
//   - the rule severity is below the configured minimum severity;
//   - runtime silencing matches this rule/event;
//   - the sliding-window rate limit is exhausted for every recipient.
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
	if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
		return false
	}
	if event.EmailSent || !rule.NotifyEmail {
		return false
	}
	emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
	if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
		return false
	}
	if len(emailCfg.Alert.Recipients) == 0 {
		return false
	}
	if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
		return false
	}
	if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
		if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
			return false
		}
	}
	// Apply/update rate limiter with the latest configured per-hour budget.
	s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)
	subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
	body := buildOpsAlertEmailBody(rule, event)
	anySent := false
	for _, to := range emailCfg.Alert.Recipients {
		addr := strings.TrimSpace(to)
		if addr == "" {
			continue
		}
		if !s.emailLimiter.Allow(time.Now().UTC()) {
			continue
		}
		if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
			// Ignore per-recipient failures; continue best-effort.
			continue
		}
		anySent = true
	}
	// Persist the sent flag so later cycles do not re-notify. Guard s.opsRepo:
	// the top-of-function checks do not cover it, and an unguarded call would
	// nil-panic on a partially wired service.
	if anySent && s.opsRepo != nil {
		_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
	}
	return anySent
}
// buildOpsAlertEmailBody renders the HTML body of an alert notification email.
// All rule/event strings are HTML-escaped before interpolation. Returns ""
// when either argument is nil.
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
	if rule == nil || event == nil {
		return ""
	}
	metricName := strings.TrimSpace(rule.MetricType)
	// "-" is the placeholder when the event carries no measured value.
	metricValue := "-"
	if event.MetricValue != nil {
		metricValue = fmt.Sprintf("%.2f", *event.MetricValue)
	}
	// Prefer the threshold snapshot stored on the event over the rule's
	// current threshold (the rule may have been edited since firing).
	thresholdText := fmt.Sprintf("%.2f", rule.Threshold)
	if event.ThresholdValue != nil {
		thresholdText = fmt.Sprintf("%.2f", *event.ThresholdValue)
	}
	return fmt.Sprintf(`
<h2>Ops Alert</h2>
<p><b>Rule</b>: %s</p>
<p><b>Severity</b>: %s</p>
<p><b>Status</b>: %s</p>
<p><b>Metric</b>: %s %s %s</p>
<p><b>Fired at</b>: %s</p>
<p><b>Description</b>: %s</p>
`,
		htmlEscape(rule.Name),
		htmlEscape(rule.Severity),
		htmlEscape(event.Status),
		htmlEscape(metricName),
		htmlEscape(rule.Operator),
		htmlEscape(fmt.Sprintf("%s (threshold %s)", metricValue, thresholdText)),
		event.FiredAt.Format(time.RFC3339),
		htmlEscape(event.Description),
	)
}
// shouldSendOpsAlertEmailByMinSeverity reports whether a rule's severity meets
// the configured minimum email severity. An empty minSeverity means "send
// everything". Severity levels rank critical(3) > warning(2) > info(1);
// unknown levels rank 0 and therefore never pass a known minimum.
func shouldSendOpsAlertEmailByMinSeverity(minSeverity string, ruleSeverity string) bool {
	minSeverity = strings.ToLower(strings.TrimSpace(minSeverity))
	if minSeverity == "" {
		return true
	}
	rank := func(level string) int {
		switch level {
		case "critical":
			return 3
		case "warning":
			return 2
		case "info":
			return 1
		default:
			return 0
		}
	}
	// opsEmailSeverityForOps already returns lowercase levels, and minSeverity
	// was lowercased above — no further normalization needed.
	return rank(opsEmailSeverityForOps(ruleSeverity)) >= rank(minSeverity)
}
// opsEmailSeverityForOps maps an ops rule severity (P-level) onto the email
// notification severity scale: P0 -> critical, P1 -> warning, anything else
// (including blank/unknown) -> info.
func opsEmailSeverityForOps(severity string) string {
	normalized := strings.ToUpper(strings.TrimSpace(severity))
	if normalized == "P0" {
		return "critical"
	}
	if normalized == "P1" {
		return "warning"
	}
	return "info"
}
// isOpsAlertSilenced reports whether the rule/event pair is currently silenced
// by the runtime silencing settings.
//
// Matching logic: a global until-timestamp silences everything; otherwise each
// entry matches when its until-timestamp has not passed, its optional RuleID
// matches, and (if severities are listed) one of them matches the event's or
// the rule's severity case-insensitively. Malformed timestamps are skipped.
//
// rule and event may be nil: the original code tolerated a nil rule in the
// RuleID comparison but then dereferenced rule.Severity / event.Severity in
// the severity filter, which would panic — both are now guarded.
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
	if !silencing.Enabled {
		return false
	}
	if now.IsZero() {
		now = time.Now().UTC()
	}
	if raw := strings.TrimSpace(silencing.GlobalUntilRFC3339); raw != "" {
		if until, err := time.Parse(time.RFC3339, raw); err == nil && now.Before(until) {
			return true
		}
	}
	for _, entry := range silencing.Entries {
		untilRaw := strings.TrimSpace(entry.UntilRFC3339)
		if untilRaw == "" {
			continue
		}
		until, err := time.Parse(time.RFC3339, untilRaw)
		if err != nil || now.After(until) {
			continue
		}
		if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
			continue
		}
		if len(entry.Severities) > 0 {
			match := false
			for _, sev := range entry.Severities {
				want := strings.TrimSpace(sev)
				if event != nil && strings.EqualFold(want, strings.TrimSpace(event.Severity)) {
					match = true
					break
				}
				if rule != nil && strings.EqualFold(want, strings.TrimSpace(rule.Severity)) {
					match = true
					break
				}
			}
			if !match {
				continue
			}
		}
		return true
	}
	return false
}
// tryAcquireLeaderLock tries to become the single active evaluator instance
// via a Redis SetNX leader lock keyed by lock.Key (default
// opsAlertEvaluatorLeaderLockKey) with TTL lock.TTLSeconds (default
// opsAlertEvaluatorLeaderLockTTL).
//
// It returns (release, ok): ok=true means this instance may run the cycle;
// release is non-nil only when a Redis lock was actually acquired and should
// be called (typically deferred) to free the lock early.
//
// NOTE(review): warnNoRedisOnce is shared between the "redis not configured"
// and "SetNX failed" paths, so only whichever fires first is ever logged per
// process — confirm that is intended.
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
	if !lock.Enabled {
		// Locking disabled: assume single instance and always proceed.
		return nil, true
	}
	if s.redisClient == nil {
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
		})
		return nil, true
	}
	key := strings.TrimSpace(lock.Key)
	if key == "" {
		key = opsAlertEvaluatorLeaderLockKey
	}
	ttl := time.Duration(lock.TTLSeconds) * time.Second
	if ttl <= 0 {
		ttl = opsAlertEvaluatorLeaderLockTTL
	}
	ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
	if err != nil {
		// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
		// Single-node deployments can disable the distributed lock via runtime settings.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
		})
		return nil, false
	}
	if !ok {
		// Another instance holds the lock; log at most once per interval.
		s.maybeLogSkip(key)
		return nil, false
	}
	// Compare-and-delete release: only removes the key if we still own it,
	// so an expired-and-reacquired lock is never stolen from the new owner.
	return func() {
		_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
	}, true
}
// maybeLogSkip logs that the leader lock is held elsewhere, throttled to at
// most one message per opsAlertEvaluatorSkipLogInterval to avoid follower
// instances spamming the log every cycle.
func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
	s.skipLogMu.Lock()
	defer s.skipLogMu.Unlock()
	current := time.Now()
	throttled := !s.skipLogAt.IsZero() && current.Sub(s.skipLogAt) < opsAlertEvaluatorSkipLogInterval
	if throttled {
		return
	}
	s.skipLogAt = current
	log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
}
// recordHeartbeatSuccess persists a successful evaluator run into the ops job
// heartbeat table. Best effort: persistence errors are deliberately ignored,
// and the write is bounded by a short 2s timeout.
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
	if s == nil || s.opsRepo == nil {
		return
	}
	summary := strings.TrimSpace(result)
	if summary == "" {
		summary = "ok"
	}
	summary = truncateString(summary, 2048)
	successAt := time.Now().UTC()
	elapsedMs := duration.Milliseconds()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAlertEvaluatorJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &elapsedMs,
		LastResult:     &summary,
	})
}
// recordHeartbeatError persists a failed evaluator run into the ops job
// heartbeat table. No-op for nil err; persistence errors are ignored and the
// write is bounded by a 2s timeout.
func (s *OpsAlertEvaluatorService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
	if s == nil || s.opsRepo == nil || err == nil {
		return
	}
	errAt := time.Now().UTC()
	elapsedMs := duration.Milliseconds()
	errText := truncateString(err.Error(), 2048)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsAlertEvaluatorJobName,
		LastRunAt:      &runAt,
		LastErrorAt:    &errAt,
		LastError:      &errText,
		LastDurationMs: &elapsedMs,
	})
}
// htmlEscapeReplacer escapes the five HTML-significant characters. Built once
// at package scope (regexp.MustCompile-style) so htmlEscape does not pay a
// replacer construction on every call.
var htmlEscapeReplacer = strings.NewReplacer(
	"&", "&amp;",
	"<", "&lt;",
	">", "&gt;",
	`"`, "&quot;",
	"'", "&#39;",
)

// htmlEscape returns s with HTML metacharacters replaced by entities, making
// it safe to interpolate into the alert email HTML body.
func htmlEscape(s string) string {
	return htmlEscapeReplacer.Replace(s)
}
type slidingWindowLimiter struct {
mu sync.Mutex
limit int
window time.Duration
sent []time.Time
}
func newSlidingWindowLimiter(limit int, window time.Duration) *slidingWindowLimiter {
if window <= 0 {
window = time.Hour
}
return &slidingWindowLimiter{
limit: limit,
window: window,
sent: []time.Time{},
}
}
func (l *slidingWindowLimiter) SetLimit(limit int) {
l.mu.Lock()
defer l.mu.Unlock()
l.limit = limit
}
func (l *slidingWindowLimiter) Allow(now time.Time) bool {
l.mu.Lock()
defer l.mu.Unlock()
if l.limit <= 0 {
return true
}
cutoff := now.Add(-l.window)
keep := l.sent[:0]
for _, t := range l.sent {
if t.After(cutoff) {
keep = append(keep, t)
}
}
l.sent = keep
if len(l.sent) >= l.limit {
return false
}
l.sent = append(l.sent, now)
return true
}
// computeGroupAvailableRatio converts a group's availability into a percentage
// in [0, 100]: (AvailableCount / TotalAccounts) * 100. A nil group or a group
// with no accounts yields 0.
func computeGroupAvailableRatio(group *GroupAvailability) float64 {
	if group == nil {
		return 0
	}
	if group.TotalAccounts <= 0 {
		return 0
	}
	return (float64(group.AvailableCount) / float64(group.TotalAccounts)) * 100
}
// countAccountsByCondition counts the non-nil accounts for which condition
// returns true. A nil condition or empty map yields 0.
func countAccountsByCondition(accounts map[int64]*AccountAvailability, condition func(*AccountAvailability) bool) int64 {
	if condition == nil {
		return 0
	}
	var matched int64
	for _, acc := range accounts {
		if acc == nil {
			continue
		}
		if condition(acc) {
			matched++
		}
	}
	return matched
}

View File

@@ -0,0 +1,210 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// stubOpsRepo is a test double for OpsRepository: it embeds the interface
// (unoverridden methods panic if called) and overrides GetDashboardOverview
// with canned data.
type stubOpsRepo struct {
	OpsRepository
	overview *OpsDashboardOverview // returned when set and err is nil
	err      error                 // takes precedence when set
}

// GetDashboardOverview returns the configured error or overview; when neither
// is set it returns a zero-value overview so callers never see nil.
func (s *stubOpsRepo) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
	if s.err != nil {
		return nil, s.err
	}
	if s.overview != nil {
		return s.overview, nil
	}
	return &OpsDashboardOverview{}, nil
}
// TestComputeGroupAvailableRatio covers the normal percentage computation and
// the zero-denominator / zero-numerator boundary cases.
func TestComputeGroupAvailableRatio(t *testing.T) {
	t.Parallel()
	t.Run("正常情况: 10个账号, 8个可用 = 80%", func(t *testing.T) {
		t.Parallel()
		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  10,
			AvailableCount: 8,
		})
		require.InDelta(t, 80.0, got, 0.0001)
	})
	t.Run("边界情况: TotalAccounts = 0 应返回 0", func(t *testing.T) {
		t.Parallel()
		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  0,
			AvailableCount: 8,
		})
		require.Equal(t, 0.0, got)
	})
	t.Run("边界情况: AvailableCount = 0 应返回 0%", func(t *testing.T) {
		t.Parallel()
		got := computeGroupAvailableRatio(&GroupAvailability{
			TotalAccounts:  10,
			AvailableCount: 0,
		})
		require.Equal(t, 0.0, got)
	})
}
// TestCountAccountsByCondition exercises the two predicates used by the alert
// evaluator (rate-limited count, error count excluding temporarily
// unschedulable accounts) plus the empty-map boundary.
func TestCountAccountsByCondition(t *testing.T) {
	t.Parallel()
	t.Run("测试限流账号统计: acc.IsRateLimited", func(t *testing.T) {
		t.Parallel()
		accounts := map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: false},
			3: {IsRateLimited: true},
		}
		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(2), got)
	})
	t.Run("测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil", func(t *testing.T) {
		t.Parallel()
		until := time.Now().UTC().Add(5 * time.Minute)
		accounts := map[int64]*AccountAvailability{
			1: {HasError: true},
			2: {HasError: true, TempUnschedulableUntil: &until}, // excluded: temporarily unschedulable
			3: {HasError: false},
		}
		got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
			return acc.HasError && acc.TempUnschedulableUntil == nil
		})
		require.Equal(t, int64(1), got)
	})
	t.Run("边界情况: 空 map 应返回 0", func(t *testing.T) {
		t.Parallel()
		got := countAccountsByCondition(map[int64]*AccountAvailability{}, func(acc *AccountAvailability) bool {
			return acc.IsRateLimited
		})
		require.Equal(t, int64(0), got)
	})
}
// TestComputeRuleMetricNewIndicators drives computeRuleMetric over the
// account-availability metric types (group_available_accounts,
// group_available_ratio, account_rate_limited_count, account_error_count)
// against a fixed availability snapshot, including the group-metrics-require-
// group-id negative cases.
func TestComputeRuleMetricNewIndicators(t *testing.T) {
	t.Parallel()
	groupID := int64(101)
	platform := "openai"
	// Fixture: 8/10 available in the group; 2 rate-limited accounts; 2 with
	// errors, one of which is temporarily unschedulable (and thus excluded
	// from account_error_count).
	availability := &OpsAccountAvailability{
		Group: &GroupAvailability{
			GroupID:        groupID,
			TotalAccounts:  10,
			AvailableCount: 8,
		},
		Accounts: map[int64]*AccountAvailability{
			1: {IsRateLimited: true},
			2: {IsRateLimited: true},
			3: {HasError: true},
			4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
			5: {HasError: false, IsRateLimited: false},
		},
	}
	opsService := &OpsService{
		getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
			return availability, nil
		},
	}
	svc := &OpsAlertEvaluatorService{
		opsService: opsService,
		opsRepo:    &stubOpsRepo{overview: &OpsDashboardOverview{}},
	}
	start := time.Now().UTC().Add(-5 * time.Minute)
	end := time.Now().UTC()
	ctx := context.Background()
	tests := []struct {
		name       string
		metricType string
		groupID    *int64
		wantValue  float64
		wantOK     bool
	}{
		{
			name:       "group_available_accounts",
			metricType: "group_available_accounts",
			groupID:    &groupID,
			wantValue:  8,
			wantOK:     true,
		},
		{
			name:       "group_available_ratio",
			metricType: "group_available_ratio",
			groupID:    &groupID,
			wantValue:  80.0,
			wantOK:     true,
		},
		{
			name:       "account_rate_limited_count",
			metricType: "account_rate_limited_count",
			groupID:    nil,
			wantValue:  2,
			wantOK:     true,
		},
		{
			name:       "account_error_count",
			metricType: "account_error_count",
			groupID:    nil,
			wantValue:  1,
			wantOK:     true,
		},
		{
			name:       "group_available_accounts without group_id returns false",
			metricType: "group_available_accounts",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
		{
			name:       "group_available_ratio without group_id returns false",
			metricType: "group_available_ratio",
			groupID:    nil,
			wantValue:  0,
			wantOK:     false,
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-1.22 loop semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			rule := &OpsAlertRule{
				MetricType: tt.metricType,
			}
			gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
			require.Equal(t, tt.wantOK, gotOK)
			if !tt.wantOK {
				return
			}
			require.InDelta(t, tt.wantValue, gotValue, 0.0001)
		})
	}
}

View File

@@ -0,0 +1,95 @@
package service
import "time"
// Ops alert rule/event models.
//
// NOTE: These are admin-facing DTOs and intentionally keep JSON naming aligned
// with the existing ops dashboard frontend (backup style).
// Lifecycle states of an ops alert event.
const (
	// OpsAlertStatusFiring: the alert condition is currently active.
	OpsAlertStatusFiring = "firing"
	// OpsAlertStatusResolved: the event has been resolved (non-manual path).
	OpsAlertStatusResolved = "resolved"
	// OpsAlertStatusManualResolved: the event was resolved by an operator.
	OpsAlertStatusManualResolved = "manual_resolved"
)
// OpsAlertRule is an admin-configurable alerting rule: it fires when
// `MetricType Operator Threshold` holds over the evaluation window.
type OpsAlertRule struct {
	ID          int64  `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	Enabled     bool   `json:"enabled"`
	// Severity is a P-level string ("P0"/"P1"/...); P0 maps to critical and
	// P1 to warning for email notifications.
	Severity string `json:"severity"`
	// MetricType selects the metric, e.g. "error_rate", "success_rate",
	// "group_available_ratio".
	MetricType string `json:"metric_type"`
	// Operator is one of > >= < <= == !=.
	Operator  string  `json:"operator"`
	Threshold float64 `json:"threshold"`
	// WindowMinutes is the evaluation lookback window (minutes).
	WindowMinutes int `json:"window_minutes"`
	// SustainedMinutes / CooldownMinutes tune firing persistence and
	// re-fire suppression — presumably enforced by the evaluator; confirm there.
	SustainedMinutes int  `json:"sustained_minutes"`
	CooldownMinutes  int  `json:"cooldown_minutes"`
	NotifyEmail      bool `json:"notify_email"`
	// Filters optionally narrows rule scope; schema is repo-defined.
	Filters         map[string]any `json:"filters,omitempty"`
	LastTriggeredAt *time.Time     `json:"last_triggered_at,omitempty"`
	CreatedAt       time.Time      `json:"created_at"`
	UpdatedAt       time.Time      `json:"updated_at"`
}
// OpsAlertEvent records one firing (and eventual resolution) of a rule.
type OpsAlertEvent struct {
	ID     int64 `json:"id"`
	RuleID int64 `json:"rule_id"`
	// Severity is copied from the rule at fire time.
	Severity string `json:"severity"`
	// Status is one of the OpsAlertStatus* constants.
	Status      string `json:"status"`
	Title       string `json:"title"`
	Description string `json:"description"`
	// MetricValue / ThresholdValue are snapshots at fire time; nil when the
	// event predates these fields or no value was measured.
	MetricValue    *float64       `json:"metric_value,omitempty"`
	ThresholdValue *float64       `json:"threshold_value,omitempty"`
	Dimensions     map[string]any `json:"dimensions,omitempty"`
	FiredAt        time.Time      `json:"fired_at"`
	ResolvedAt     *time.Time     `json:"resolved_at,omitempty"`
	// EmailSent guards against duplicate notification emails.
	EmailSent bool      `json:"email_sent"`
	CreatedAt time.Time `json:"created_at"`
}
// OpsAlertSilence suppresses notifications for a rule within a scope
// (platform, optional group/region) until a deadline. RuleID, a non-blank
// Platform and a non-zero Until are required at creation time.
type OpsAlertSilence struct {
	ID       int64  `json:"id"`
	RuleID   int64  `json:"rule_id"`
	Platform string `json:"platform"`
	GroupID  *int64 `json:"group_id,omitempty"`
	Region   *string `json:"region,omitempty"`
	// Until is the silence expiry instant.
	Until  time.Time `json:"until"`
	Reason string    `json:"reason"`
	// CreatedBy is the admin user id, when known.
	CreatedBy *int64    `json:"created_by,omitempty"`
	CreatedAt time.Time `json:"created_at"`
}
// OpsAlertEventFilter selects alert events for listing. Pagination is
// cursor-based, descending by (fired_at, id); Before* fields carry the cursor
// of the last row from the previous page.
type OpsAlertEventFilter struct {
	Limit int
	// Cursor pagination (descending by fired_at, then id).
	BeforeFiredAt *time.Time
	BeforeID      *int64
	// Optional filters.
	Status    string
	Severity  string
	EmailSent *bool
	StartTime *time.Time
	EndTime   *time.Time
	// Dimensions filters (best-effort).
	Platform string
	GroupID  *int64
}

View File

@@ -0,0 +1,232 @@
package service
import (
"context"
"database/sql"
"errors"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// ListAlertRules returns all configured alert rules. When the ops repository
// is not wired it returns an empty slice rather than an error, since listing
// is read-only.
func (s *OpsService) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return []*OpsAlertRule{}, nil
	}
	return s.opsRepo.ListAlertRules(ctx)
}
// CreateAlertRule persists a new alert rule and returns the stored rule.
// Errors: ServiceUnavailable when the ops repository is not wired; BadRequest
// for a nil rule.
func (s *OpsService) CreateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if rule == nil {
		return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
	}
	// Return the repo result directly; the previous intermediate
	// (created, err) variables added nothing.
	return s.opsRepo.CreateAlertRule(ctx, rule)
}
// UpdateAlertRule updates an existing rule and returns the stored result.
// sql.ErrNoRows from the repository is translated to a NotFound API error.
func (s *OpsService) UpdateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if rule == nil || rule.ID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
	}
	updated, err := s.opsRepo.UpdateAlertRule(ctx, rule)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
		}
		return nil, err
	}
	return updated, nil
}
// DeleteAlertRule removes the rule with the given id. sql.ErrNoRows is
// translated to a NotFound API error; a non-positive id is a BadRequest.
func (s *OpsService) DeleteAlertRule(ctx context.Context, id int64) error {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return err
	}
	if s.opsRepo == nil {
		return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if id <= 0 {
		return infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	if err := s.opsRepo.DeleteAlertRule(ctx, id); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
		}
		return err
	}
	return nil
}
// ListAlertEvents returns alert events matching filter. With no repository
// wired it returns an empty slice (read-only call, mirrors ListAlertRules).
func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return []*OpsAlertEvent{}, nil
	}
	return s.opsRepo.ListAlertEvents(ctx, filter)
}
// GetAlertEventByID fetches a single alert event. Both sql.ErrNoRows and a
// nil result from the repository map to the same NotFound API error.
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if eventID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
	}
	ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
		}
		return nil, err
	}
	if ev == nil {
		return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
	}
	return ev, nil
}
// GetActiveAlertEvent returns the currently active (firing) event for the
// rule, delegating straight to the repository. Unlike GetAlertEventByID it
// does not translate "no rows" — presumably the repo returns (nil, nil) when
// no event is active; confirm against the repository implementation.
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if ruleID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
}
// CreateAlertSilence validates and persists a silence window for a rule.
// Validation: non-nil input, rule id > 0, non-blank platform, non-zero until
// timestamp; each violation returns a distinct BadRequest code.
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if input == nil {
		return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
	}
	if input.RuleID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	if strings.TrimSpace(input.Platform) == "" {
		return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
	}
	if input.Until.IsZero() {
		return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
	}
	// Return the repo result directly; the previous intermediate
	// (created, err) variables added nothing.
	return s.opsRepo.CreateAlertSilence(ctx, input)
}
// IsAlertSilenced reports whether any silence currently covers the rule in the
// given (platform, group, region) scope at instant now.
// Note: a blank platform short-circuits to (false, nil) — silences are always
// platform-scoped here — rather than being rejected as invalid.
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return false, err
	}
	if s.opsRepo == nil {
		return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if ruleID <= 0 {
		return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	if strings.TrimSpace(platform) == "" {
		return false, nil
	}
	return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
}
// GetLatestAlertEvent returns the most recent event for the rule, delegating
// to the repository after the standard guard checks.
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if ruleID <= 0 {
		return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
	}
	return s.opsRepo.GetLatestAlertEvent(ctx, ruleID)
}
// CreateAlertEvent persists a new alert event and returns the stored event.
// Errors: ServiceUnavailable when the ops repository is not wired; BadRequest
// for a nil event.
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if event == nil {
		return nil, infraerrors.BadRequest("INVALID_EVENT", "invalid event")
	}
	// Return the repo result directly; the previous intermediate
	// (created, err) variables added nothing.
	return s.opsRepo.CreateAlertEvent(ctx, event)
}
// UpdateAlertEventStatus transitions an event to a resolved state. Only
// OpsAlertStatusResolved and OpsAlertStatusManualResolved are accepted —
// events cannot be put back into "firing" through this API.
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return err
	}
	if s.opsRepo == nil {
		return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if eventID <= 0 {
		return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
	}
	status = strings.TrimSpace(status)
	if status == "" {
		return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
	}
	if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
		return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
	}
	return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
}
// UpdateAlertEventEmailSent sets/clears the email_sent flag on an event,
// used by the evaluator to avoid duplicate notification emails.
func (s *OpsService) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return err
	}
	if s.opsRepo == nil {
		return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if eventID <= 0 {
		return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
	}
	return s.opsRepo.UpdateAlertEventEmailSent(ctx, eventID, emailSent)
}

View File

@@ -0,0 +1,367 @@
package service
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/robfig/cron/v3"
)
const (
	// opsCleanupJobName is the heartbeat job name recorded for cleanup runs.
	opsCleanupJobName = "ops_cleanup"
	// opsCleanupLeaderLockKeyDefault is the Redis key used for leader election.
	opsCleanupLeaderLockKeyDefault = "ops:cleanup:leader"
	// opsCleanupLeaderLockTTLDefault bounds how long a crashed leader can
	// block other instances from running cleanup.
	opsCleanupLeaderLockTTLDefault = 30 * time.Minute
)

// opsCleanupCronParser accepts classic 5-field cron specs
// (minute hour dom month dow).
var opsCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)

// opsCleanupReleaseScript deletes the leader lock only when this instance
// still owns it (compare-and-delete), so an expired lock re-acquired by
// another node is never released by mistake.
var opsCleanupReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// OpsCleanupService periodically deletes old ops data to prevent unbounded DB growth.
//
// - Scheduling: 5-field cron spec (minute hour dom month dow).
// - Multi-instance: best-effort Redis leader lock so only one node runs cleanup.
// - Safety: deletes in batches to avoid long transactions.
type OpsCleanupService struct {
	opsRepo     OpsRepository  // heartbeat persistence
	db          *sql.DB        // raw handle used for batched DELETEs
	redisClient *redis.Client  // optional; enables the cross-instance leader lock
	cfg         *config.Config // retention days, schedule, run mode
	// instanceID uniquely identifies this process as the lock owner.
	instanceID string
	cron       *cron.Cron
	startOnce  sync.Once
	stopOnce   sync.Once
	// warnNoRedisOnce limits redis-related warnings to once per process.
	warnNoRedisOnce sync.Once
}
// NewOpsCleanupService wires a cleanup service with a fresh random instance id
// (used as the leader-lock owner token). Call Start to begin scheduling.
func NewOpsCleanupService(
	opsRepo OpsRepository,
	db *sql.DB,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsCleanupService {
	return &OpsCleanupService{
		opsRepo:     opsRepo,
		db:          db,
		redisClient: redisClient,
		cfg:         cfg,
		instanceID:  uuid.NewString(),
	}
}
// Start schedules the cleanup cron job. It is a no-op when ops or cleanup is
// disabled in config, or when required dependencies are missing. Safe to call
// multiple times; only the first call takes effect.
func (s *OpsCleanupService) Start() {
	if s == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled {
		log.Printf("[OpsCleanup] not started (disabled)")
		return
	}
	if s.opsRepo == nil || s.db == nil {
		log.Printf("[OpsCleanup] not started (missing deps)")
		return
	}
	s.startOnce.Do(func() {
		// Default schedule: daily at 02:00 in the configured timezone.
		schedule := "0 2 * * *"
		if s.cfg != nil && strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule) != "" {
			schedule = strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule)
		}
		// Resolve the cron timezone; fall back to the host's local zone when
		// unset or unparsable.
		loc := time.Local
		if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" {
			if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
				loc = parsed
			}
		}
		c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc))
		_, err := c.AddFunc(schedule, func() { s.runScheduled() })
		if err != nil {
			// Invalid spec: log and leave the service unstarted.
			log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err)
			return
		}
		s.cron = c
		s.cron.Start()
		log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
	})
}
// Stop halts the cron scheduler and waits up to 3s for an in-flight job to
// finish. Safe to call multiple times and on a never-started service.
func (s *OpsCleanupService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		if s.cron != nil {
			// cron.Stop returns a context that is done once running jobs exit.
			ctx := s.cron.Stop()
			select {
			case <-ctx.Done():
			case <-time.After(3 * time.Second):
				log.Printf("[OpsCleanup] cron stop timed out")
			}
		}
	})
}
// runScheduled is the cron entry point: it acquires the leader lock, runs one
// cleanup pass bounded by a 30-minute timeout, and records a success or error
// heartbeat either way. Non-leaders return silently.
func (s *OpsCleanupService) runScheduled() {
	if s == nil || s.db == nil || s.opsRepo == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
	defer cancel()
	release, ok := s.tryAcquireLeaderLock(ctx)
	if !ok {
		return
	}
	// release is nil when no explicit lock was taken (e.g. simple run mode).
	if release != nil {
		defer release()
	}
	startedAt := time.Now().UTC()
	runAt := startedAt
	counts, err := s.runCleanupOnce(ctx)
	if err != nil {
		s.recordHeartbeatError(runAt, time.Since(startedAt), err)
		log.Printf("[OpsCleanup] cleanup failed: %v", err)
		return
	}
	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
	log.Printf("[OpsCleanup] cleanup complete: %s", counts)
}
// opsCleanupDeletedCounts tallies rows deleted per ops table during one
// cleanup pass; it is logged and stored in the job heartbeat result.
type opsCleanupDeletedCounts struct {
	errorLogs     int64
	retryAttempts int64
	alertEvents   int64
	systemMetrics int64
	hourlyPreagg  int64
	dailyPreagg   int64
}

// String renders the counts as a stable space-separated key=value line.
func (c opsCleanupDeletedCounts) String() string {
	return fmt.Sprintf(
		"error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
		c.errorLogs, c.retryAttempts, c.alertEvents, c.systemMetrics, c.hourlyPreagg, c.dailyPreagg,
	)
}
// runCleanupOnce performs a single retention pass over all ops tables, in
// batches of 5000 rows. Each table group is gated on its retention setting
// (days <= 0 disables that group). Returns per-table deletion counts; stops
// at the first error, returning whatever was deleted so far.
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
	out := opsCleanupDeletedCounts{}
	if s == nil || s.db == nil || s.cfg == nil {
		return out, nil
	}
	batchSize := 5000
	now := time.Now().UTC()
	// Error-like tables: error logs / retry attempts / alert events.
	if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
		cutoff := now.AddDate(0, 0, -days)
		n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
		if err != nil {
			return out, err
		}
		out.errorLogs = n
		n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
		if err != nil {
			return out, err
		}
		out.retryAttempts = n
		n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
		if err != nil {
			return out, err
		}
		out.alertEvents = n
	}
	// Minute-level metrics snapshots.
	if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
		cutoff := now.AddDate(0, 0, -days)
		n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
		if err != nil {
			return out, err
		}
		out.systemMetrics = n
	}
	// Pre-aggregation tables (hourly/daily).
	// NOTE(review): both the hourly and daily tables are gated on
	// HourlyMetricsRetentionDays — confirm the daily table should not use its
	// own retention setting.
	if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
		cutoff := now.AddDate(0, 0, -days)
		n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
		if err != nil {
			return out, err
		}
		out.hourlyPreagg = n
		// bucket_date is a DATE column, so the cutoff is cast to ::date.
		n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
		if err != nil {
			return out, err
		}
		out.dailyPreagg = n
	}
	return out, nil
}
// deleteOldRowsByID removes rows older than cutoff from table in batches of
// batchSize, looping until a batch deletes zero rows. The CTE form
// (SELECT id ... LIMIT n, then DELETE by id) keeps each statement short and
// bounded rather than one long-running DELETE.
//
// castCutoffToDate compares against `$1::date` for DATE-typed columns (e.g.
// ops_metrics_daily.bucket_date). "relation does not exist" errors are
// swallowed so partial deployments without the ops tables are a no-op.
//
// NOTE(review): placeholders are PostgreSQL-style ($1/$2) and table/column
// names are interpolated from trusted internal constants only — this helper
// assumes a Postgres backend and must never receive user input for table or
// timeColumn.
func deleteOldRowsByID(
	ctx context.Context,
	db *sql.DB,
	table string,
	timeColumn string,
	cutoff time.Time,
	batchSize int,
	castCutoffToDate bool,
) (int64, error) {
	if db == nil {
		return 0, nil
	}
	if batchSize <= 0 {
		batchSize = 5000
	}
	where := fmt.Sprintf("%s < $1", timeColumn)
	if castCutoffToDate {
		where = fmt.Sprintf("%s < $1::date", timeColumn)
	}
	q := fmt.Sprintf(`
WITH batch AS (
SELECT id FROM %s
WHERE %s
ORDER BY id
LIMIT $2
)
DELETE FROM %s
WHERE id IN (SELECT id FROM batch)
`, table, where, table)
	var total int64
	for {
		res, err := db.ExecContext(ctx, q, cutoff, batchSize)
		if err != nil {
			// If ops tables aren't present yet (partial deployments), treat as no-op.
			if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
				return total, nil
			}
			return total, err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return total, err
		}
		total += affected
		// A short batch means the table is drained below the cutoff.
		if affected == 0 {
			break
		}
	}
	return total, nil
}
// tryAcquireLeaderLock elects a single cleanup runner. Order of preference:
//  1. simple run mode: assume single instance, no lock needed;
//  2. Redis SetNX lock (released via compare-and-delete script);
//  3. on Redis error or absence: DB advisory lock fallback.
//
// Returns (release, ok); release may be nil even when ok is true (case 1).
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
	if s == nil {
		return nil, false
	}
	// In simple run mode, assume single instance.
	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
		return nil, true
	}
	key := opsCleanupLeaderLockKeyDefault
	ttl := opsCleanupLeaderLockTTLDefault
	// Prefer Redis leader lock when available, but avoid stampeding the DB when Redis is flaky by
	// falling back to a DB advisory lock.
	if s.redisClient != nil {
		ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
		if err == nil {
			if !ok {
				// Another instance leads this cycle.
				return nil, false
			}
			return func() {
				_, _ = opsCleanupReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
			}, true
		}
		// Redis error: fall back to DB advisory lock.
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err)
		})
	} else {
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsCleanup] redis not configured; using DB advisory lock")
		})
	}
	release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
	if !ok {
		return nil, false
	}
	return release, true
}
// recordHeartbeatSuccess persists a successful cleanup run (with the deletion
// summary) into the ops job heartbeat table. Best effort: persistence errors
// are ignored and the write is bounded by a 2s timeout.
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
	if s == nil || s.opsRepo == nil {
		return
	}
	successAt := time.Now().UTC()
	elapsedMs := duration.Milliseconds()
	summary := truncateString(counts.String(), 2048)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsCleanupJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &elapsedMs,
		LastResult:     &summary,
	})
}
// recordHeartbeatError persists a failed cleanup run into the ops job
// heartbeat table. No-op for nil err; persistence errors are ignored and the
// write is bounded by a 2s timeout.
func (s *OpsCleanupService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
	if s == nil || s.opsRepo == nil || err == nil {
		return
	}
	errAt := time.Now().UTC()
	elapsedMs := duration.Milliseconds()
	errText := truncateString(err.Error(), 2048)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsCleanupJobName,
		LastRunAt:      &runAt,
		LastErrorAt:    &errAt,
		LastError:      &errText,
		LastDurationMs: &elapsedMs,
	})
}

View File

@@ -0,0 +1,257 @@
package service
import (
"context"
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
	// opsAccountsPageSize is the page size used when sweeping all accounts.
	opsAccountsPageSize = 100
	// opsConcurrencyBatchChunkSize caps the number of accounts sent to a
	// single load-batch query.
	opsConcurrencyBatchChunkSize = 200
)
// listAllAccountsForOps loads every account (optionally restricted to one
// platform) by paging through the account repository.
//
// Best-effort: a nil service/repo yields an empty slice rather than an error.
// Paging stops on an empty page, once the repo-reported total is reached, on
// a short page, or after a hard safety cap of 10k pages.
func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter string) ([]Account, error) {
	if s == nil || s.accountRepo == nil {
		return []Account{}, nil
	}
	out := make([]Account, 0, 128)
	page := 1
	for {
		accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
			Page:     page,
			PageSize: opsAccountsPageSize,
		}, platformFilter, "", "", "")
		if err != nil {
			return nil, err
		}
		if len(accounts) == 0 {
			break
		}
		out = append(out, accounts...)
		// Collected everything the repository says exists.
		if pageInfo != nil && int64(len(out)) >= pageInfo.Total {
			break
		}
		// Short page: no further data.
		if len(accounts) < opsAccountsPageSize {
			break
		}
		page++
		// Safety valve against a misbehaving repository (endless paging).
		if page > 10_000 {
			log.Printf("[Ops] listAllAccountsForOps: aborting after too many pages (platform=%q)", platformFilter)
			break
		}
	}
	return out, nil
}
// getAccountsLoadMapBestEffort fetches current load info for the given
// accounts, keyed by account ID. Best-effort: any batch failure results in
// missing entries (callers treat absence as zero load), never an error.
func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts []Account) map[int64]*AccountLoadInfo {
	if s == nil || s.concurrencyService == nil {
		return map[int64]*AccountLoadInfo{}
	}
	if len(accounts) == 0 {
		return map[int64]*AccountLoadInfo{}
	}
	// De-duplicate IDs (and keep the max concurrency to avoid under-reporting).
	unique := make(map[int64]int, len(accounts))
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
			unique[acc.ID] = acc.Concurrency
		}
	}
	batch := make([]AccountWithConcurrency, 0, len(unique))
	for id, maxConc := range unique {
		batch = append(batch, AccountWithConcurrency{
			ID:             id,
			MaxConcurrency: maxConc,
		})
	}
	out := make(map[int64]*AccountLoadInfo, len(batch))
	// Chunk the batch so a single query stays bounded.
	for i := 0; i < len(batch); i += opsConcurrencyBatchChunkSize {
		end := i + opsConcurrencyBatchChunkSize
		if end > len(batch) {
			end = len(batch)
		}
		part, err := s.concurrencyService.GetAccountsLoadBatch(ctx, batch[i:end])
		if err != nil {
			// Best-effort: return zeros rather than failing the ops UI.
			log.Printf("[Ops] GetAccountsLoadBatch failed: %v", err)
			continue
		}
		for k, v := range part {
			out[k] = v
		}
	}
	return out
}
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
//
// Optional filters:
//   - platformFilter: only include accounts in that platform (best-effort reduces DB load)
//   - groupIDFilter: only include accounts that belong to that group
//
// Returns four values plus error: per-platform, per-group, and per-account
// aggregates, and the collection timestamp. Load data is best-effort: when
// the load map has no entry for an account, its usage counts as zero.
func (s *OpsService) GetConcurrencyStats(
	ctx context.Context,
	platformFilter string,
	groupIDFilter *int64,
) (map[string]*PlatformConcurrencyInfo, map[int64]*GroupConcurrencyInfo, map[int64]*AccountConcurrencyInfo, *time.Time, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, nil, nil, nil, err
	}
	accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
	if err != nil {
		return nil, nil, nil, nil, err
	}
	collectedAt := time.Now()
	loadMap := s.getAccountsLoadMapBestEffort(ctx, accounts)
	platform := make(map[string]*PlatformConcurrencyInfo)
	group := make(map[int64]*GroupConcurrencyInfo)
	account := make(map[int64]*AccountConcurrencyInfo)
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		// Resolve the group filter (if any) against this account's groups.
		var matchedGroup *Group
		if groupIDFilter != nil && *groupIDFilter > 0 {
			for _, grp := range acc.Groups {
				if grp == nil || grp.ID <= 0 {
					continue
				}
				if grp.ID == *groupIDFilter {
					matchedGroup = grp
					break
				}
			}
			// Group filter provided: skip accounts not in that group.
			if matchedGroup == nil {
				continue
			}
		}
		// Missing load entry means zero current usage / zero waiting.
		load := loadMap[acc.ID]
		currentInUse := int64(0)
		waiting := int64(0)
		if load != nil {
			currentInUse = int64(load.CurrentConcurrency)
			waiting = int64(load.WaitingCount)
		}
		// Account-level view picks one display group (the first group).
		displayGroupID := int64(0)
		displayGroupName := ""
		if matchedGroup != nil {
			displayGroupID = matchedGroup.ID
			displayGroupName = matchedGroup.Name
		} else if len(acc.Groups) > 0 && acc.Groups[0] != nil {
			displayGroupID = acc.Groups[0].ID
			displayGroupName = acc.Groups[0].Name
		}
		if _, ok := account[acc.ID]; !ok {
			info := &AccountConcurrencyInfo{
				AccountID:      acc.ID,
				AccountName:    acc.Name,
				Platform:       acc.Platform,
				GroupID:        displayGroupID,
				GroupName:      displayGroupName,
				CurrentInUse:   currentInUse,
				MaxCapacity:    int64(acc.Concurrency),
				WaitingInQueue: waiting,
			}
			if info.MaxCapacity > 0 {
				info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
			}
			account[acc.ID] = info
		}
		// Platform aggregation.
		if acc.Platform != "" {
			if _, ok := platform[acc.Platform]; !ok {
				platform[acc.Platform] = &PlatformConcurrencyInfo{
					Platform: acc.Platform,
				}
			}
			p := platform[acc.Platform]
			p.MaxCapacity += int64(acc.Concurrency)
			p.CurrentInUse += currentInUse
			p.WaitingInQueue += waiting
		}
		// Group aggregation (one account may contribute to multiple groups).
		if matchedGroup != nil {
			grp := matchedGroup
			if _, ok := group[grp.ID]; !ok {
				group[grp.ID] = &GroupConcurrencyInfo{
					GroupID:   grp.ID,
					GroupName: grp.Name,
					Platform:  grp.Platform,
				}
			}
			g := group[grp.ID]
			if g.GroupName == "" && grp.Name != "" {
				g.GroupName = grp.Name
			}
			if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
				// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
				g.Platform = ""
			}
			g.MaxCapacity += int64(acc.Concurrency)
			g.CurrentInUse += currentInUse
			g.WaitingInQueue += waiting
		} else {
			// No group filter: credit the account to every group it belongs to.
			for _, grp := range acc.Groups {
				if grp == nil || grp.ID <= 0 {
					continue
				}
				if _, ok := group[grp.ID]; !ok {
					group[grp.ID] = &GroupConcurrencyInfo{
						GroupID:   grp.ID,
						GroupName: grp.Name,
						Platform:  grp.Platform,
					}
				}
				g := group[grp.ID]
				if g.GroupName == "" && grp.Name != "" {
					g.GroupName = grp.Name
				}
				if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
					// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
					g.Platform = ""
				}
				g.MaxCapacity += int64(acc.Concurrency)
				g.CurrentInUse += currentInUse
				g.WaitingInQueue += waiting
			}
		}
	}
	// Derive percentages once totals are final.
	for _, info := range platform {
		if info.MaxCapacity > 0 {
			info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
		}
	}
	for _, info := range group {
		if info.MaxCapacity > 0 {
			info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
		}
	}
	return platform, group, account, &collectedAt, nil
}

View File

@@ -0,0 +1,90 @@
package service
import (
"context"
"database/sql"
"errors"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetDashboardOverview returns the aggregated ops dashboard metrics for the
// requested time window.
//
// Validation: the filter and both window bounds are required, and the window
// must satisfy start <= end. The effective query mode (raw vs pre-aggregated)
// is resolved server-side before querying.
//
// System metrics and job heartbeats are attached best-effort: their failures
// are logged (sql.ErrNoRows is silently treated as "no snapshot yet") and
// never block the dashboard from rendering.
func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if filter == nil {
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	}
	if filter.StartTime.After(filter.EndTime) {
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	}
	// Resolve query mode (requested via query param, or DB default).
	filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
	overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
	if err != nil {
		if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
			return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")
		}
		return nil, err
	}
	// Best-effort system health + jobs; dashboard metrics should still render if these are missing.
	if metrics, err := s.opsRepo.GetLatestSystemMetrics(ctx, 1); err == nil {
		// Attach config-derived limits so the UI can show "current / max" for connection pools.
		// These are best-effort and should never block the dashboard rendering.
		// (s is known non-nil here: it was dereferenced above.)
		if s.cfg != nil {
			if s.cfg.Database.MaxOpenConns > 0 {
				metrics.DBMaxOpenConns = intPtr(s.cfg.Database.MaxOpenConns)
			}
			if s.cfg.Redis.PoolSize > 0 {
				metrics.RedisPoolSize = intPtr(s.cfg.Redis.PoolSize)
			}
		}
		overview.SystemMetrics = metrics
	} else if !errors.Is(err, sql.ErrNoRows) {
		// err != nil is implied in this branch; ErrNoRows just means no snapshot yet.
		log.Printf("[Ops] GetLatestSystemMetrics failed: %v", err)
	}
	if heartbeats, err := s.opsRepo.ListJobHeartbeats(ctx); err == nil {
		overview.JobHeartbeats = heartbeats
	} else {
		log.Printf("[Ops] ListJobHeartbeats failed: %v", err)
	}
	overview.HealthScore = computeDashboardHealthScore(time.Now().UTC(), overview)
	return overview, nil
}
// resolveOpsQueryMode turns the caller-requested query mode into the
// effective one. An explicit valid mode wins; otherwise the DB-stored default
// is used. In both paths, "auto" is downgraded to "raw" when pre-aggregated
// tables are disabled in config.
func (s *OpsService) resolveOpsQueryMode(ctx context.Context, requested OpsQueryMode) OpsQueryMode {
	if requested.IsValid() {
		// Allow "auto" to be disabled via config until preagg is proven stable in production.
		// Forced `preagg` via query param still works.
		if requested == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
			return OpsQueryModeRaw
		}
		return requested
	}
	// No valid mode requested: fall back to the configured default setting.
	mode := OpsQueryModeAuto
	if s != nil && s.settingRepo != nil {
		if raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsQueryModeDefault); err == nil {
			mode = ParseOpsQueryMode(raw)
		}
	}
	if mode == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
		return OpsQueryModeRaw
	}
	return mode
}

View File

@@ -0,0 +1,87 @@
package service
import "time"
// OpsDashboardFilter scopes ops dashboard queries to a time window plus
// optional platform/group restrictions.
type OpsDashboardFilter struct {
	// StartTime/EndTime bound the query window (both required; start <= end).
	StartTime time.Time
	EndTime   time.Time
	// Platform optionally restricts metrics to a single platform ("" = all).
	Platform string
	// GroupID optionally restricts metrics to one account group (nil = all).
	GroupID *int64
	// QueryMode controls whether dashboard queries should use raw logs or pre-aggregated tables.
	// Expected values: auto/raw/preagg (see OpsQueryMode).
	QueryMode OpsQueryMode
}
// OpsRateSummary summarizes a per-second rate metric (e.g. QPS/TPS):
// the current value alongside the peak and average over the window.
type OpsRateSummary struct {
	Current float64 `json:"current"`
	Peak    float64 `json:"peak"`
	Avg     float64 `json:"avg"`
}
// OpsPercentiles carries latency statistics in milliseconds (see json tags).
// Fields are pointers; consumers nil-check before use (nil presumably means
// no data for the window — confirm against the repository queries).
type OpsPercentiles struct {
	P50 *int `json:"p50_ms"`
	P90 *int `json:"p90_ms"`
	P95 *int `json:"p95_ms"`
	P99 *int `json:"p99_ms"`
	Avg *int `json:"avg_ms"`
	Max *int `json:"max_ms"`
}
// OpsDashboardOverview is the full payload rendered by the ops dashboard:
// the echoed filter, a computed health score, best-effort system/job status,
// and the request/error/latency aggregates for the window.
type OpsDashboardOverview struct {
	// Echo of the effective query window and filters.
	StartTime time.Time `json:"start_time"`
	EndTime   time.Time `json:"end_time"`
	Platform  string    `json:"platform"`
	GroupID   *int64    `json:"group_id"`
	// HealthScore is a backend-computed overall health score (0-100).
	// It is derived from the monitored metrics in this overview, plus best-effort system metrics/job heartbeats.
	HealthScore int `json:"health_score"`
	// Latest system-level snapshot (window=1m, global).
	SystemMetrics *OpsSystemMetricsSnapshot `json:"system_metrics"`
	// Background jobs health (heartbeats).
	JobHeartbeats []*OpsJobHeartbeat `json:"job_heartbeats"`
	// Request/error counters for the window.
	SuccessCount         int64 `json:"success_count"`
	ErrorCountTotal      int64 `json:"error_count_total"`
	BusinessLimitedCount int64 `json:"business_limited_count"`
	ErrorCountSLA        int64 `json:"error_count_sla"`
	RequestCountTotal    int64 `json:"request_count_total"`
	RequestCountSLA      int64 `json:"request_count_sla"`
	TokenConsumed        int64 `json:"token_consumed"`
	// Derived ratios.
	SLA               float64 `json:"sla"`
	ErrorRate         float64 `json:"error_rate"`
	UpstreamErrorRate float64 `json:"upstream_error_rate"`
	// Upstream error breakdown (429/529 tracked separately from the rest).
	UpstreamErrorCountExcl429529 int64 `json:"upstream_error_count_excl_429_529"`
	Upstream429Count             int64 `json:"upstream_429_count"`
	Upstream529Count             int64 `json:"upstream_529_count"`
	// Rates and latency distributions.
	QPS      OpsRateSummary `json:"qps"`
	TPS      OpsRateSummary `json:"tps"`
	Duration OpsPercentiles `json:"duration"`
	TTFT     OpsPercentiles `json:"ttft"`
}
// OpsLatencyHistogramBucket is one histogram bar: a human-readable latency
// range label plus the number of requests that fell into that range.
type OpsLatencyHistogramBucket struct {
	Range string `json:"range"`
	Count int64  `json:"count"`
}
// OpsLatencyHistogramResponse is a coarse latency distribution histogram (success requests only).
// It is used by the Ops dashboard to quickly identify tail latency regressions.
type OpsLatencyHistogramResponse struct {
	// Echo of the effective query window and filters.
	StartTime time.Time `json:"start_time"`
	EndTime   time.Time `json:"end_time"`
	Platform  string    `json:"platform"`
	GroupID   *int64    `json:"group_id"`
	// TotalRequests is the number of requests distributed across Buckets.
	TotalRequests int64                        `json:"total_requests"`
	Buckets       []*OpsLatencyHistogramBucket `json:"buckets"`
}

View File

@@ -0,0 +1,45 @@
package service
import (
"context"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetErrorTrend returns time-bucketed error counts for the requested window.
// Only validation happens here; the repository does the actual query.
func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	switch {
	case s.opsRepo == nil:
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	case filter == nil:
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	case filter.StartTime.IsZero() || filter.EndTime.IsZero():
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	case filter.StartTime.After(filter.EndTime):
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	}
	return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
}
// GetErrorDistribution returns the breakdown of errors for the requested
// window. Only validation happens here; the repository does the query.
func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	switch {
	case s.opsRepo == nil:
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	case filter == nil:
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	case filter.StartTime.IsZero() || filter.EndTime.IsZero():
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	case filter.StartTime.After(filter.EndTime):
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	}
	return s.opsRepo.GetErrorDistribution(ctx, filter)
}

View File

@@ -0,0 +1,143 @@
package service
import (
"math"
"time"
)
// computeDashboardHealthScore computes a 0-100 health score from the metrics returned by the dashboard overview.
//
// Design goals:
//   - Backend-owned scoring (UI only displays).
//   - Layered scoring: Business Health (70%) + Infrastructure Health (30%)
//   - Avoids double-counting (e.g., DB failure affects both infra and business metrics)
//   - Conservative + stable: penalize clear degradations; avoid overreacting to missing/idle data.
func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview) int {
	if overview == nil {
		return 0
	}
	// Idle/no-data: avoid showing a "bad" score when there is no traffic.
	// UI can still render a gray/idle state based on QPS + error rate.
	idle := overview.RequestCountSLA <= 0 &&
		overview.RequestCountTotal <= 0 &&
		overview.ErrorCountTotal <= 0
	if idle {
		return 100
	}
	// Weighted combination: 70% business + 30% infrastructure.
	const businessWeight, infraWeight = 0.7, 0.3
	combined := computeBusinessHealth(overview)*businessWeight +
		computeInfraHealth(now, overview)*infraWeight
	return int(math.Round(clampFloat64(combined, 0, 100)))
}
// computeBusinessHealth scores request-facing health on a 0-100 scale.
// Two equally weighted signals:
//   - error rate: <=1% is perfect, >=10% scores zero (linear in between);
//     the worse of request error rate and upstream error rate is used
//   - TTFT p99: <=1s is perfect, >=3s scores zero (linear in between)
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
	// Worst-case of the two error signals, as percentages clamped to [0,100].
	worstErrPct := math.Max(
		clampFloat64(overview.ErrorRate*100, 0, 100),
		clampFloat64(overview.UpstreamErrorRate*100, 0, 100),
	)
	errScore := 100.0
	switch {
	case worstErrPct > 10.0:
		errScore = 0
	case worstErrPct > 1.0:
		errScore = (10.0 - worstErrPct) / 9.0 * 100
	}
	// Time to first token is critical for user experience.
	ttftScore := 100.0
	if p := overview.TTFT.P99; p != nil {
		p99 := float64(*p)
		switch {
		case p99 > 3000:
			ttftScore = 0
		case p99 > 1000:
			ttftScore = (3000 - p99) / 2000 * 100
		}
	}
	// Weighted combination: 50% error rate + 50% TTFT.
	return errScore*0.5 + ttftScore*0.5
}
// computeInfraHealth calculates infrastructure health score (0-100)
// Components: Storage (40%) + Compute Resources (30%) + Background Jobs (30%)
//
// Missing data is treated optimistically: a nil SystemMetrics or empty
// heartbeat list leaves the corresponding component at a perfect 100.
func computeInfraHealth(now time.Time, overview *OpsDashboardOverview) float64 {
	// Storage score: DB critical, Redis less critical
	storageScore := 100.0
	if overview.SystemMetrics != nil {
		if overview.SystemMetrics.DBOK != nil && !*overview.SystemMetrics.DBOK {
			storageScore = 0 // DB failure is critical
		} else if overview.SystemMetrics.RedisOK != nil && !*overview.SystemMetrics.RedisOK {
			storageScore = 50 // Redis failure is degraded but not critical
		}
	}
	// Compute resources score: CPU + Memory
	computeScore := 100.0
	if overview.SystemMetrics != nil {
		// CPU: <=80% is perfect, 100% scores zero (linear in between).
		cpuScore := 100.0
		if overview.SystemMetrics.CPUUsagePercent != nil {
			cpuPct := clampFloat64(*overview.SystemMetrics.CPUUsagePercent, 0, 100)
			if cpuPct > 80 {
				if cpuPct <= 100 {
					cpuScore = (100 - cpuPct) / 20 * 100
				} else {
					cpuScore = 0
				}
			}
		}
		// Memory: <=85% is perfect, 100% scores zero (linear in between).
		memScore := 100.0
		if overview.SystemMetrics.MemoryUsagePercent != nil {
			memPct := clampFloat64(*overview.SystemMetrics.MemoryUsagePercent, 0, 100)
			if memPct > 85 {
				if memPct <= 100 {
					memScore = (100 - memPct) / 15 * 100
				} else {
					memScore = 0
				}
			}
		}
		computeScore = (cpuScore + memScore) / 2
	}
	// Background jobs score: fraction of jobs that are healthy. A job counts
	// as failed when its latest error is newer than its latest success, or
	// when its last success is older than 15 minutes.
	jobScore := 100.0
	failedJobs := 0
	totalJobs := 0
	for _, hb := range overview.JobHeartbeats {
		if hb == nil {
			continue
		}
		totalJobs++
		if hb.LastErrorAt != nil && (hb.LastSuccessAt == nil || hb.LastErrorAt.After(*hb.LastSuccessAt)) {
			failedJobs++
		} else if hb.LastSuccessAt != nil && now.Sub(*hb.LastSuccessAt) > 15*time.Minute {
			failedJobs++
		}
	}
	if totalJobs > 0 && failedJobs > 0 {
		jobScore = (1 - float64(failedJobs)/float64(totalJobs)) * 100
	}
	// Weighted combination
	return storageScore*0.4 + computeScore*0.3 + jobScore*0.3
}
// clampFloat64 restricts v to the inclusive range [lower, upper].
func clampFloat64(v float64, lower float64, upper float64) float64 {
	switch {
	case v < lower:
		return lower
	case v > upper:
		return upper
	}
	return v
}

View File

@@ -0,0 +1,442 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// An all-zero overview means "no traffic": the score must report a perfect
// 100 rather than penalizing idleness.
func TestComputeDashboardHealthScore_IdleReturns100(t *testing.T) {
	t.Parallel()
	score := computeDashboardHealthScore(time.Now().UTC(), &OpsDashboardOverview{})
	require.Equal(t, 100, score)
}
// With traffic present and every signal bad (high error rates, slow TTFT,
// DB/Redis down, saturated CPU/memory, a recently failed job), the score
// must clearly degrade below 80 while staying within [0, 100].
func TestComputeDashboardHealthScore_DegradesOnBadSignals(t *testing.T) {
	t.Parallel()
	ov := &OpsDashboardOverview{
		RequestCountTotal: 100,
		RequestCountSLA:   100,
		SuccessCount:      90,
		ErrorCountTotal:   10,
		ErrorCountSLA:     10,
		SLA:               0.90,
		ErrorRate:         0.10,
		UpstreamErrorRate: 0.08,
		Duration:          OpsPercentiles{P99: intPtr(20_000)},
		TTFT:              OpsPercentiles{P99: intPtr(2_000)},
		SystemMetrics: &OpsSystemMetricsSnapshot{
			DBOK:                  boolPtr(false),
			RedisOK:               boolPtr(false),
			CPUUsagePercent:       float64Ptr(98.0),
			MemoryUsagePercent:    float64Ptr(97.0),
			DBConnWaiting:         intPtr(3),
			ConcurrencyQueueDepth: intPtr(10),
		},
		JobHeartbeats: []*OpsJobHeartbeat{
			{
				JobName:     "job-a",
				LastErrorAt: timePtr(time.Now().UTC().Add(-1 * time.Minute)),
				LastError:   stringPtr("boom"),
			},
		},
	}
	score := computeDashboardHealthScore(time.Now().UTC(), ov)
	require.Less(t, score, 80)
	require.GreaterOrEqual(t, score, 0)
}
// Table-driven coverage of the overall score: each case pins an expected
// [wantMin, wantMax] band for a scenario (perfect health, graded SLA levels,
// isolated DB/Redis/CPU failures, and mixed business/infra degradations).
// Bands rather than exact values keep the test robust to small weight tweaks.
func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		overview *OpsDashboardOverview
		wantMin  int
		wantMax  int
	}{
		{
			name:     "nil overview returns 0",
			overview: nil,
			wantMin:  0,
			wantMax:  0,
		},
		{
			name: "perfect health",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               1.0,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
				TTFT:              OpsPercentiles{P99: intPtr(100)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "good health - SLA 99.8%",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.998,
				ErrorRate:         0.003,
				UpstreamErrorRate: 0.001,
				Duration:          OpsPercentiles{P99: intPtr(800)},
				TTFT:              OpsPercentiles{P99: intPtr(200)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(50),
					MemoryUsagePercent: float64Ptr(60),
				},
			},
			wantMin: 95,
			wantMax: 100,
		},
		{
			name: "medium health - SLA 96%",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.96,
				ErrorRate:         0.02,
				UpstreamErrorRate: 0.01,
				Duration:          OpsPercentiles{P99: intPtr(3000)},
				TTFT:              OpsPercentiles{P99: intPtr(600)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(70),
					MemoryUsagePercent: float64Ptr(75),
				},
			},
			wantMin: 96,
			wantMax: 97,
		},
		{
			name: "DB failure",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.995,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(false),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 70,
			wantMax: 90,
		},
		{
			name: "Redis failure",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.995,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(false),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 85,
			wantMax: 95,
		},
		{
			name: "high CPU usage",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.995,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(95),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 85,
			wantMax: 100,
		},
		{
			name: "combined failures - business degraded + infra healthy",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.90,
				ErrorRate:         0.05,
				UpstreamErrorRate: 0.02,
				Duration:          OpsPercentiles{P99: intPtr(10000)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(20),
					MemoryUsagePercent: float64Ptr(30),
				},
			},
			wantMin: 84,
			wantMax: 85,
		},
		{
			name: "combined failures - business healthy + infra degraded",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				RequestCountSLA:   1000,
				SLA:               0.998,
				ErrorRate:         0.001,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(600)},
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(false),
					RedisOK:            boolPtr(false),
					CPUUsagePercent:    float64Ptr(95),
					MemoryUsagePercent: float64Ptr(95),
				},
			},
			wantMin: 70,
			wantMax: 90,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			score := computeDashboardHealthScore(time.Now().UTC(), tt.overview)
			require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %d", tt.wantMin)
			require.LessOrEqual(t, score, tt.wantMax, "score should be <= %d", tt.wantMax)
			// Hard invariant regardless of scenario: score stays in [0, 100].
			require.GreaterOrEqual(t, score, 0, "score must be >= 0")
			require.LessOrEqual(t, score, 100, "score must be <= 100")
		})
	}
}
// Table-driven coverage of the business-health component: boundary values of
// the error-rate mapping (1% and beyond), the TTFT mapping, and the rule
// that the worse of request/upstream error rates drives the error score.
func TestComputeBusinessHealth(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		overview *OpsDashboardOverview
		wantMin  float64
		wantMax  float64
	}{
		{
			name: "perfect metrics",
			overview: &OpsDashboardOverview{
				SLA:               1.0,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "SLA boundary 99.5%",
			overview: &OpsDashboardOverview{
				SLA:               0.995,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "SLA boundary 95%",
			overview: &OpsDashboardOverview{
				SLA:               0.95,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "error rate boundary 1%",
			overview: &OpsDashboardOverview{
				SLA:               0.99,
				ErrorRate:         0.01,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "error rate 5%",
			overview: &OpsDashboardOverview{
				SLA:               0.95,
				ErrorRate:         0.05,
				UpstreamErrorRate: 0,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 77,
			wantMax: 78,
		},
		{
			name: "TTFT boundary 2s",
			overview: &OpsDashboardOverview{
				SLA:               0.99,
				ErrorRate:         0,
				UpstreamErrorRate: 0,
				TTFT:              OpsPercentiles{P99: intPtr(2000)},
			},
			wantMin: 75,
			wantMax: 75,
		},
		{
			name: "upstream error dominates",
			overview: &OpsDashboardOverview{
				SLA:               0.995,
				ErrorRate:         0.001,
				UpstreamErrorRate: 0.03,
				Duration:          OpsPercentiles{P99: intPtr(500)},
			},
			wantMin: 88,
			wantMax: 90,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			score := computeBusinessHealth(tt.overview)
			require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
			require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
			// Hard invariant: component score stays in [0, 100].
			require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
			require.LessOrEqual(t, score, 100.0, "score must be <= 100")
		})
	}
}
// Table-driven coverage of the infrastructure-health component: DB vs Redis
// failure weighting, the CPU threshold mapping, and the heartbeat-based
// background-job penalty.
func TestComputeInfraHealth(t *testing.T) {
	t.Parallel()
	now := time.Now().UTC()
	tests := []struct {
		name     string
		overview *OpsDashboardOverview
		wantMin  float64
		wantMax  float64
	}{
		{
			name: "all infrastructure healthy",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 100,
			wantMax: 100,
		},
		{
			name: "DB down",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(false),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 50,
			wantMax: 70,
		},
		{
			name: "Redis down",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(false),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 80,
			wantMax: 95,
		},
		{
			name: "CPU at 90%",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(90),
					MemoryUsagePercent: float64Ptr(40),
				},
			},
			wantMin: 85,
			wantMax: 95,
		},
		{
			name: "failed background job",
			overview: &OpsDashboardOverview{
				RequestCountTotal: 1000,
				SystemMetrics: &OpsSystemMetricsSnapshot{
					DBOK:               boolPtr(true),
					RedisOK:            boolPtr(true),
					CPUUsagePercent:    float64Ptr(30),
					MemoryUsagePercent: float64Ptr(40),
				},
				JobHeartbeats: []*OpsJobHeartbeat{
					{
						JobName:     "test-job",
						LastErrorAt: &now,
					},
				},
			},
			wantMin: 70,
			wantMax: 90,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			score := computeInfraHealth(now, tt.overview)
			require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
			require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
			// Hard invariant: component score stays in [0, 100].
			require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
			require.LessOrEqual(t, score, 100.0, "score must be <= 100")
		})
	}
}
func timePtr(v time.Time) *time.Time { return &v }
// stringPtr returns a pointer to a copy of v (handy for optional fields).
func stringPtr(v string) *string {
	out := v
	return &out
}

View File

@@ -0,0 +1,26 @@
package service
import (
"context"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetLatencyHistogram returns the coarse latency distribution for the
// requested window. Only validation happens here; the repository queries.
func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	switch {
	case s.opsRepo == nil:
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	case filter == nil:
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	case filter.StartTime.IsZero() || filter.EndTime.IsZero():
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	case filter.StartTime.After(filter.EndTime):
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	}
	return s.opsRepo.GetLatencyHistogram(ctx, filter)
}

View File

@@ -0,0 +1,920 @@
package service
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"math"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/mem"
)
const (
	// opsMetricsCollectorJobName identifies this job in heartbeat records.
	opsMetricsCollectorJobName = "ops_metrics_collector"
	// Collection interval bounds; the admin setting is clamped into this range.
	opsMetricsCollectorMinInterval = 60 * time.Second
	opsMetricsCollectorMaxInterval = 1 * time.Hour
	// opsMetricsCollectorTimeout bounds a single collection pass.
	opsMetricsCollectorTimeout = 10 * time.Second
	// Leader election key/TTL so only one instance collects at a time.
	opsMetricsCollectorLeaderLockKey = "ops:metrics:collector:leader"
	opsMetricsCollectorLeaderLockTTL = 90 * time.Second
	// opsMetricsCollectorHeartbeatTimeout bounds heartbeat upserts.
	opsMetricsCollectorHeartbeatTimeout = 2 * time.Second
	// bytesPerMB converts byte counts to megabytes for reporting.
	bytesPerMB = 1024 * 1024
)
// opsMetricsCollectorAdvisoryLockID is the DB advisory-lock key used as a
// fallback when Redis-based leader election is unavailable.
var opsMetricsCollectorAdvisoryLockID = hashAdvisoryLockID(opsMetricsCollectorLeaderLockKey)
// OpsMetricsCollector periodically samples system/usage metrics and persists
// them through the ops repository. A single leader instance collects at a
// time (Redis lock with DB advisory-lock fallback).
type OpsMetricsCollector struct {
	opsRepo            OpsRepository
	settingRepo        SettingRepository
	cfg                *config.Config
	accountRepo        AccountRepository
	concurrencyService *ConcurrencyService
	db                 *sql.DB
	redisClient        *redis.Client
	// instanceID uniquely identifies this process for leader-lock ownership.
	instanceID string
	// Last cgroup CPU sample, used to derive usage deltas between runs.
	lastCgroupCPUUsageNanos uint64
	lastCgroupCPUSampleAt   time.Time
	// stopCh signals the run loop to exit; start/stop are one-shot.
	stopCh    chan struct{}
	startOnce sync.Once
	stopOnce  sync.Once
	// Throttles repeated "skipped" log lines.
	skipLogMu sync.Mutex
	skipLogAt time.Time
}
// NewOpsMetricsCollector wires a collector with its dependencies and assigns
// a fresh random instance ID for leader-lock ownership. Call Start to begin
// collecting and Stop to shut down.
func NewOpsMetricsCollector(
	opsRepo OpsRepository,
	settingRepo SettingRepository,
	accountRepo AccountRepository,
	concurrencyService *ConcurrencyService,
	db *sql.DB,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsMetricsCollector {
	return &OpsMetricsCollector{
		opsRepo:            opsRepo,
		settingRepo:        settingRepo,
		cfg:                cfg,
		accountRepo:        accountRepo,
		concurrencyService: concurrencyService,
		db:                 db,
		redisClient:        redisClient,
		instanceID:         uuid.NewString(),
	}
}
// Start launches the background collection goroutine. Idempotent (guarded by
// startOnce) and safe to call on a nil receiver.
func (c *OpsMetricsCollector) Start() {
	if c == nil {
		return
	}
	c.startOnce.Do(func() {
		if c.stopCh == nil {
			c.stopCh = make(chan struct{})
		}
		go c.run()
	})
}
// Stop signals the collection loop to exit. Idempotent (guarded by stopOnce)
// and safe to call on a nil receiver or before Start.
func (c *OpsMetricsCollector) Stop() {
	if c == nil {
		return
	}
	c.stopOnce.Do(func() {
		if c.stopCh != nil {
			close(c.stopCh)
		}
	})
}
// run is the collector's main loop: one immediate collection, then a timer
// that is re-armed each cycle. The interval is re-read every iteration so
// setting changes take effect without a restart; closing stopCh exits.
func (c *OpsMetricsCollector) run() {
	// First run immediately so the dashboard has data soon after startup.
	c.collectOnce()
	for {
		interval := c.getInterval()
		timer := time.NewTimer(interval)
		select {
		case <-timer.C:
			c.collectOnce()
		case <-c.stopCh:
			timer.Stop()
			return
		}
	}
}
// getInterval returns the collection interval, reading the optional admin
// setting and clamping it into [min, max]. Any read or parse problem falls
// back to the minimum interval.
func (c *OpsMetricsCollector) getInterval() time.Duration {
	fallback := opsMetricsCollectorMinInterval
	if c.settingRepo == nil {
		return fallback
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	raw, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMetricsIntervalSeconds)
	if err != nil {
		return fallback
	}
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return fallback
	}
	seconds, err := strconv.Atoi(trimmed)
	if err != nil {
		return fallback
	}
	// Clamp the configured value into the supported range.
	lo := int(opsMetricsCollectorMinInterval.Seconds())
	hi := int(opsMetricsCollectorMaxInterval.Seconds())
	switch {
	case seconds < lo:
		seconds = lo
	case seconds > hi:
		seconds = hi
	}
	return time.Duration(seconds) * time.Second
}
// collectOnce performs a single guarded collection pass: it checks the
// config kill-switch and runtime setting, acquires the cross-instance leader
// lock, runs collectAndPersist, and records the outcome as a job heartbeat.
// Heartbeat writes use their own short timeout and ignore errors.
func (c *OpsMetricsCollector) collectOnce() {
	if c == nil {
		return
	}
	if c.cfg != nil && !c.cfg.Ops.Enabled {
		return
	}
	if c.opsRepo == nil {
		return
	}
	if c.db == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectorTimeout)
	defer cancel()
	if !c.isMonitoringEnabled(ctx) {
		return
	}
	// Only the leader instance collects; others return quietly.
	release, ok := c.tryAcquireLeaderLock(ctx)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	startedAt := time.Now().UTC()
	err := c.collectAndPersist(ctx)
	finishedAt := time.Now().UTC()
	durationMs := finishedAt.Sub(startedAt).Milliseconds()
	dur := durationMs
	runAt := startedAt
	if err != nil {
		// Record the failure heartbeat, log, and bail.
		msg := truncateString(err.Error(), 2048)
		errAt := finishedAt
		hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
		defer hbCancel()
		_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
			JobName:        opsMetricsCollectorJobName,
			LastRunAt:      &runAt,
			LastErrorAt:    &errAt,
			LastError:      &msg,
			LastDurationMs: &dur,
		})
		log.Printf("[OpsMetricsCollector] collect failed: %v", err)
		return
	}
	// Record the success heartbeat.
	successAt := finishedAt
	hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
	defer hbCancel()
	_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsMetricsCollectorJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &successAt,
		LastDurationMs: &dur,
	})
}
// isMonitoringEnabled reports whether ops monitoring is switched on,
// combining the hard config switch with the soft DB setting. Missing keys
// and transient lookup errors are treated as enabled (fail-open) so the
// collector never becomes a hard dependency.
func (c *OpsMetricsCollector) isMonitoringEnabled(ctx context.Context) bool {
	if c == nil {
		return false
	}
	if c.cfg != nil && !c.cfg.Ops.Enabled {
		return false
	}
	if c.settingRepo == nil {
		return true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	raw, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
	if err != nil {
		// Missing key => default enabled; any other error => fail-open.
		if errors.Is(err, ErrSettingNotFound) {
			return true
		}
		return true
	}
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "false", "0", "off", "disabled":
		return false
	}
	return true
}
// collectAndPersist samples the most recent completed one-minute window of
// usage, error, and system metrics and stores it as a single row through the
// ops repository.
func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
	if ctx == nil {
		ctx = context.Background()
	}
	// Align to stable minute boundaries to avoid partial buckets and to
	// maximize cache hits.
	windowEnd := time.Now().UTC().Truncate(time.Minute)
	windowStart := windowEnd.Add(-time.Minute)

	sys, err := c.collectSystemStats(ctx)
	if err != nil {
		// System stats are best-effort; keep collecting the rest.
		log.Printf("[OpsMetricsCollector] system stats error: %v", err)
	}
	dbOK := c.checkDB(ctx)
	redisOK := c.checkRedis(ctx)
	dbActive, dbIdle := c.dbPoolStats()
	redisTotal, redisIdle, redisStatsOK := c.redisPoolStats()

	successCount, tokenConsumed, err := c.queryUsageCounts(ctx, windowStart, windowEnd)
	if err != nil {
		return fmt.Errorf("query usage counts: %w", err)
	}
	duration, ttft, err := c.queryUsageLatency(ctx, windowStart, windowEnd)
	if err != nil {
		return fmt.Errorf("query usage latency: %w", err)
	}
	errorTotal, businessLimited, errorSLA, upstreamExcl, upstream429, upstream529, err := c.queryErrorCounts(ctx, windowStart, windowEnd)
	if err != nil {
		return fmt.Errorf("query error counts: %w", err)
	}

	windowSeconds := windowEnd.Sub(windowStart).Seconds()
	if windowSeconds <= 0 {
		windowSeconds = 60
	}
	requestTotal := successCount + errorTotal

	// Redis pool gauges are only meaningful when the stats were readable.
	var redisConnTotal, redisConnIdle *int
	if redisStatsOK {
		redisConnTotal = intPtr(redisTotal)
		redisConnIdle = intPtr(redisIdle)
	}

	input := &OpsInsertSystemMetricsInput{
		CreatedAt:                    windowEnd,
		WindowMinutes:                1,
		SuccessCount:                 successCount,
		ErrorCountTotal:              errorTotal,
		BusinessLimitedCount:         businessLimited,
		ErrorCountSLA:                errorSLA,
		UpstreamErrorCountExcl429529: upstreamExcl,
		Upstream429Count:             upstream429,
		Upstream529Count:             upstream529,
		TokenConsumed:                tokenConsumed,
		QPS:                          float64Ptr(roundTo1DP(float64(requestTotal) / windowSeconds)),
		TPS:                          float64Ptr(roundTo1DP(float64(tokenConsumed) / windowSeconds)),
		DurationP50Ms:                duration.p50,
		DurationP90Ms:                duration.p90,
		DurationP95Ms:                duration.p95,
		DurationP99Ms:                duration.p99,
		DurationAvgMs:                duration.avg,
		DurationMaxMs:                duration.max,
		TTFTP50Ms:                    ttft.p50,
		TTFTP90Ms:                    ttft.p90,
		TTFTP95Ms:                    ttft.p95,
		TTFTP99Ms:                    ttft.p99,
		TTFTAvgMs:                    ttft.avg,
		TTFTMaxMs:                    ttft.max,
		CPUUsagePercent:              sys.cpuUsagePercent,
		MemoryUsedMB:                 sys.memoryUsedMB,
		MemoryTotalMB:                sys.memoryTotalMB,
		MemoryUsagePercent:           sys.memoryUsagePercent,
		DBOK:                         boolPtr(dbOK),
		RedisOK:                      boolPtr(redisOK),
		RedisConnTotal:               redisConnTotal,
		RedisConnIdle:                redisConnIdle,
		DBConnActive:                 intPtr(dbActive),
		DBConnIdle:                   intPtr(dbIdle),
		GoroutineCount:               intPtr(runtime.NumGoroutine()),
		ConcurrencyQueueDepth:        c.collectConcurrencyQueueDepth(ctx),
	}
	return c.opsRepo.InsertSystemMetrics(ctx, input)
}
// collectConcurrencyQueueDepth sums the number of requests waiting in the
// per-account concurrency queues across all schedulable accounts. It is
// best-effort: any failure yields nil rather than breaking the collector.
func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Context) *int {
	if c == nil || c.accountRepo == nil || c.concurrencyService == nil {
		return nil
	}
	if parentCtx == nil {
		parentCtx = context.Background()
	}
	ctx, cancel := context.WithTimeout(parentCtx, 2*time.Second)
	defer cancel()

	accounts, err := c.accountRepo.ListSchedulable(ctx)
	if err != nil {
		return nil
	}
	query := make([]AccountWithConcurrency, 0, len(accounts))
	for _, acc := range accounts {
		if acc.ID <= 0 {
			continue
		}
		limit := acc.Concurrency
		if limit < 0 {
			limit = 0
		}
		query = append(query, AccountWithConcurrency{
			ID:             acc.ID,
			MaxConcurrency: limit,
		})
	}
	if len(query) == 0 {
		zero := 0
		return &zero
	}
	loads, err := c.concurrencyService.GetAccountsLoadBatch(ctx, query)
	if err != nil {
		return nil
	}
	var waiting int64
	for _, info := range loads {
		if info != nil && info.WaitingCount > 0 {
			waiting += int64(info.WaitingCount)
		}
	}
	if waiting < 0 {
		waiting = 0
	}
	// Clamp into the platform int range before converting.
	if maxInt := int64(^uint(0) >> 1); waiting > maxInt {
		waiting = maxInt
	}
	depth := int(waiting)
	return &depth
}
// opsCollectedPercentiles carries one latency distribution sampled from
// usage_logs. Pointer fields are nil when the window contained no rows
// (the SQL aggregates return NULL in that case).
type opsCollectedPercentiles struct {
	p50 *int
	p90 *int
	p95 *int
	p99 *int
	avg *float64
	max *int
}
// queryUsageCounts counts usage_logs rows inside [start, end) (these rows are
// treated as successful requests) and sums their token usage
// (input + output + cache creation/read).
func (c *OpsMetricsCollector) queryUsageCounts(ctx context.Context, start, end time.Time) (successCount int64, tokenConsumed int64, err error) {
	q := `
SELECT
  COALESCE(COUNT(*), 0) AS success_count,
  COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2`
	// SUM() can still scan as NULL through some drivers; NullInt64 maps that to 0.
	var tokens sql.NullInt64
	if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&successCount, &tokens); err != nil {
		return 0, 0, err
	}
	if tokens.Valid {
		tokenConsumed = tokens.Int64
	}
	return successCount, tokenConsumed, nil
}
// queryUsageLatency returns request-duration and time-to-first-token
// percentile summaries for usage_logs rows inside [start, end).
func (c *OpsMetricsCollector) queryUsageLatency(ctx context.Context, start, end time.Time) (duration opsCollectedPercentiles, ttft opsCollectedPercentiles, err error) {
	// Both distributions share the same aggregate shape; only the sampled
	// column differs, so they are computed by a shared helper.
	duration, err = c.queryLatencyColumn(ctx, "duration_ms", start, end)
	if err != nil {
		return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
	}
	ttft, err = c.queryLatencyColumn(ctx, "first_token_ms", start, end)
	if err != nil {
		return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
	}
	return duration, ttft, nil
}

// queryLatencyColumn computes p50/p90/p95/p99/avg/max for one latency column
// of usage_logs within [start, end). The column name is a trusted constant
// supplied by queryUsageLatency, never user input, so interpolating it into
// the SQL text is safe.
func (c *OpsMetricsCollector) queryLatencyColumn(ctx context.Context, column string, start, end time.Time) (opsCollectedPercentiles, error) {
	q := fmt.Sprintf(`
SELECT
  percentile_cont(0.50) WITHIN GROUP (ORDER BY %[1]s) AS p50,
  percentile_cont(0.90) WITHIN GROUP (ORDER BY %[1]s) AS p90,
  percentile_cont(0.95) WITHIN GROUP (ORDER BY %[1]s) AS p95,
  percentile_cont(0.99) WITHIN GROUP (ORDER BY %[1]s) AS p99,
  AVG(%[1]s) AS avg_ms,
  MAX(%[1]s) AS max_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
  AND %[1]s IS NOT NULL`, column)
	var out opsCollectedPercentiles
	var p50, p90, p95, p99, avg sql.NullFloat64
	var max sql.NullInt64
	if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
		return opsCollectedPercentiles{}, err
	}
	out.p50 = floatToIntPtr(p50)
	out.p90 = floatToIntPtr(p90)
	out.p95 = floatToIntPtr(p95)
	out.p99 = floatToIntPtr(p99)
	if avg.Valid {
		v := roundTo1DP(avg.Float64)
		out.avg = &v
	}
	if max.Valid {
		v := int(max.Int64)
		out.max = &v
	}
	return out, nil
}
// queryErrorCounts aggregates ops_error_logs rows inside [start, end) into
// the per-window error counters:
//   - errorTotal: all rows with status >= 400
//   - businessLimited: the >=400 subset flagged is_business_limited
//   - errorSLA: the >=400 subset NOT business-limited
//   - upstreamExcl429529 / upstream429 / upstream529: provider-owned,
//     non-business-limited errors split by upstream status code
//     (upstream_status_code falls back to status_code when absent).
func (c *OpsMetricsCollector) queryErrorCounts(ctx context.Context, start, end time.Time) (
	errorTotal int64,
	businessLimited int64,
	errorSLA int64,
	upstreamExcl429529 int64,
	upstream429 int64,
	upstream529 int64,
	err error,
) {
	q := `
SELECT
  COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400), 0) AS error_total,
  COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited), 0) AS business_limited,
  COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited), 0) AS error_sla,
  COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)), 0) AS upstream_excl,
  COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429), 0) AS upstream_429,
  COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529), 0) AS upstream_529
FROM ops_error_logs
WHERE created_at >= $1 AND created_at < $2`
	if err := c.db.QueryRowContext(ctx, q, start, end).Scan(
		&errorTotal,
		&businessLimited,
		&errorSLA,
		&upstreamExcl429529,
		&upstream429,
		&upstream529,
	); err != nil {
		return 0, 0, 0, 0, 0, 0, err
	}
	return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
}
// opsCollectedSystemStats holds best-effort CPU/memory samples. Each field is
// nil when neither the cgroup nor the host fallback could provide a value.
type opsCollectedSystemStats struct {
	cpuUsagePercent    *float64
	memoryUsedMB       *int64
	memoryTotalMB      *int64
	memoryUsagePercent *float64
}
// collectSystemStats samples CPU and memory usage. cgroup (container)
// readings take precedence; gopsutil host readings fill any gaps. The
// returned error is currently always nil — partial data is expressed through
// nil fields instead.
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) (*opsCollectedSystemStats, error) {
	out := &opsCollectedSystemStats{}
	if ctx == nil {
		ctx = context.Background()
	}
	sampleAt := time.Now().UTC()
	// Prefer cgroup (container) metrics when available.
	if cpuPct := c.tryCgroupCPUPercent(sampleAt); cpuPct != nil {
		out.cpuUsagePercent = cpuPct
	}
	cgroupUsed, cgroupTotal, cgroupOK := readCgroupMemoryBytes()
	if cgroupOK {
		usedMB := int64(cgroupUsed / bytesPerMB)
		out.memoryUsedMB = &usedMB
		// cgroupTotal == 0 means "no limit" (memory.max = "max"); the
		// total/percent fields are then filled from the host below.
		if cgroupTotal > 0 {
			totalMB := int64(cgroupTotal / bytesPerMB)
			out.memoryTotalMB = &totalMB
			pct := roundTo1DP(float64(cgroupUsed) / float64(cgroupTotal) * 100)
			out.memoryUsagePercent = &pct
		}
	}
	// Fallback to host metrics if cgroup metrics are unavailable (or incomplete).
	if out.cpuUsagePercent == nil {
		if cpuPercents, err := cpu.PercentWithContext(ctx, 0, false); err == nil && len(cpuPercents) > 0 {
			v := roundTo1DP(cpuPercents[0])
			out.cpuUsagePercent = &v
		}
	}
	// If total memory isn't available from cgroup (e.g. memory.max = "max"), fill total from host.
	if out.memoryUsedMB == nil || out.memoryTotalMB == nil || out.memoryUsagePercent == nil {
		if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil && vm != nil {
			if out.memoryUsedMB == nil {
				usedMB := int64(vm.Used / bytesPerMB)
				out.memoryUsedMB = &usedMB
			}
			if out.memoryTotalMB == nil {
				totalMB := int64(vm.Total / bytesPerMB)
				out.memoryTotalMB = &totalMB
			}
			if out.memoryUsagePercent == nil {
				// May mix cgroup "used" with host "total" when only the cgroup
				// limit was missing; percent then comes from MB-rounded values.
				if out.memoryUsedMB != nil && out.memoryTotalMB != nil && *out.memoryTotalMB > 0 {
					pct := roundTo1DP(float64(*out.memoryUsedMB) / float64(*out.memoryTotalMB) * 100)
					out.memoryUsagePercent = &pct
				} else {
					pct := roundTo1DP(vm.UsedPercent)
					out.memoryUsagePercent = &pct
				}
			}
		}
	}
	return out, nil
}
// tryCgroupCPUPercent derives container CPU usage (as a percent of the cgroup
// CPU quota) from the delta between the current and previously sampled
// cumulative usage counters. It returns nil when cgroup data is unavailable,
// on the first (baseline) sample, after a counter reset, or when no CPU limit
// is configured.
func (c *OpsMetricsCollector) tryCgroupCPUPercent(now time.Time) *float64 {
	currentNanos, ok := readCgroupCPUUsageNanos()
	if !ok {
		return nil
	}
	firstSample := c.lastCgroupCPUSampleAt.IsZero()
	elapsed := now.Sub(c.lastCgroupCPUSampleAt)
	previousNanos := c.lastCgroupCPUUsageNanos
	// Always advance the stored sample before deciding whether a percentage
	// can be computed, so the next cycle has a fresh baseline.
	c.lastCgroupCPUUsageNanos = currentNanos
	c.lastCgroupCPUSampleAt = now
	if firstSample || elapsed <= 0 {
		return nil
	}
	if currentNanos < previousNanos {
		// Counter reset (container restarted).
		return nil
	}
	cores := readCgroupCPULimitCores()
	if cores <= 0 {
		// Can't reliably normalize; skip and fall back to gopsutil.
		return nil
	}
	elapsedSec := elapsed.Seconds()
	if elapsedSec <= 0 {
		return nil
	}
	usedSec := float64(currentNanos-previousNanos) / 1e9
	pct := usedSec / (elapsedSec * cores) * 100
	// Clamp to avoid noise/jitter showing impossible values.
	if pct < 0 {
		pct = 0
	} else if pct > 100 {
		pct = 100
	}
	v := roundTo1DP(pct)
	return &v
}
// readCgroupMemoryBytes reports the container's memory usage and (when
// limited) total, preferring cgroup v2 and falling back to cgroup v1.
// totalBytes stays 0 when the cgroup reports no limit; ok is false when
// neither hierarchy is readable.
func readCgroupMemoryBytes() (usedBytes uint64, totalBytes uint64, ok bool) {
	// cgroup v2 (most common in modern containers).
	if used, haveUsed := readUintFile("/sys/fs/cgroup/memory.current"); haveUsed {
		if raw, err := os.ReadFile("/sys/fs/cgroup/memory.max"); err == nil {
			limit := strings.TrimSpace(string(raw))
			// "max" means unlimited; leave totalBytes at 0 in that case.
			if limit != "" && limit != "max" {
				if v, perr := strconv.ParseUint(limit, 10, 64); perr == nil {
					totalBytes = v
				}
			}
		}
		return used, totalBytes, true
	}
	// cgroup v1 fallback.
	if used, haveUsed := readUintFile("/sys/fs/cgroup/memory/memory.usage_in_bytes"); haveUsed {
		if limit, haveLimit := readUintFile("/sys/fs/cgroup/memory/memory.limit_in_bytes"); haveLimit {
			// Unlimited v1 cgroups report an absurdly large sentinel; ignore it.
			if limit > 0 && limit < (1<<60) {
				totalBytes = limit
			}
		}
		return used, totalBytes, true
	}
	return 0, 0, false
}
// readCgroupCPUUsageNanos returns the container's cumulative CPU time in
// nanoseconds, preferring cgroup v2 (cpu.stat usage_usec) and falling back to
// cgroup v1 (cpuacct.usage, already in nanoseconds).
func readCgroupCPUUsageNanos() (usageNanos uint64, ok bool) {
	if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.stat"); err == nil {
		for _, line := range strings.Split(string(raw), "\n") {
			fields := strings.Fields(line)
			if len(fields) != 2 || fields[0] != "usage_usec" {
				continue
			}
			if usec, perr := strconv.ParseUint(fields[1], 10, 64); perr == nil {
				// cpu.stat reports microseconds; convert to nanoseconds.
				return usec * 1000, true
			}
		}
	}
	if nanos, have := readUintFile("/sys/fs/cgroup/cpuacct/cpuacct.usage"); have {
		return nanos, true
	}
	return 0, false
}
// readCgroupCPULimitCores returns the CPU limit in (possibly fractional)
// cores, from cgroup v2 cpu.max or cgroup v1 cfs quota/period. It returns 0
// when no limit can be determined.
func readCgroupCPULimitCores() float64 {
	// cgroup v2: cpu.max => "<quota> <period>" or "max <period>".
	if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.max"); err == nil {
		fields := strings.Fields(string(raw))
		if len(fields) >= 2 && fields[0] != "max" {
			quota, errQuota := strconv.ParseFloat(fields[0], 64)
			period, errPeriod := strconv.ParseFloat(fields[1], 64)
			if errQuota == nil && errPeriod == nil && quota > 0 && period > 0 {
				return quota / period
			}
		}
	}
	// cgroup v1: cpu.cfs_quota_us / cpu.cfs_period_us.
	quotaUs, haveQuota := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
	periodUs, havePeriod := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
	if haveQuota && havePeriod && quotaUs > 0 && periodUs > 0 {
		return float64(quotaUs) / float64(periodUs)
	}
	return 0
}
// readUintFile reads path, trims surrounding whitespace, and parses the
// remainder as an unsigned base-10 integer. The bool is false when the file
// is missing/unreadable, empty, or not a valid uint64.
func readUintFile(path string) (uint64, bool) {
	data, err := os.ReadFile(path)
	if err != nil {
		return 0, false
	}
	text := strings.TrimSpace(string(data))
	if text == "" {
		return 0, false
	}
	n, err := strconv.ParseUint(text, 10, 64)
	if err != nil {
		return 0, false
	}
	return n, true
}
// readIntFile reads path and parses its whitespace-trimmed contents as a
// signed base-10 integer. The bool is false when the file is
// missing/unreadable, empty, or not a valid int64.
func readIntFile(path string) (int64, bool) {
	data, err := os.ReadFile(path)
	if err != nil {
		return 0, false
	}
	text := strings.TrimSpace(string(data))
	if text == "" {
		return 0, false
	}
	n, err := strconv.ParseInt(text, 10, 64)
	if err != nil {
		return 0, false
	}
	return n, true
}
// checkDB reports whether the primary database answers a trivial query.
func (c *OpsMetricsCollector) checkDB(ctx context.Context) bool {
	if c == nil || c.db == nil {
		return false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	var got int
	err := c.db.QueryRowContext(ctx, "SELECT 1").Scan(&got)
	return err == nil && got == 1
}
// checkRedis reports whether Redis responds to PING.
func (c *OpsMetricsCollector) checkRedis(ctx context.Context) bool {
	if c == nil || c.redisClient == nil {
		return false
	}
	if ctx == nil {
		ctx = context.Background()
	}
	err := c.redisClient.Ping(ctx).Err()
	return err == nil
}
// redisPoolStats returns the Redis connection pool's total and idle
// connection counts; ok is false when no client or stats are available.
func (c *OpsMetricsCollector) redisPoolStats() (total int, idle int, ok bool) {
	if c == nil || c.redisClient == nil {
		return 0, 0, false
	}
	st := c.redisClient.PoolStats()
	if st == nil {
		return 0, 0, false
	}
	return int(st.TotalConns), int(st.IdleConns), true
}
// dbPoolStats returns the database/sql pool's in-use and idle connection
// counts (both 0 when no DB handle is configured).
func (c *OpsMetricsCollector) dbPoolStats() (active int, idle int) {
	if c == nil || c.db == nil {
		return 0, 0
	}
	st := c.db.Stats()
	return st.InUse, st.Idle
}
// opsMetricsCollectorReleaseScript atomically deletes the leader-lock key
// only if it still holds this instance's ID, so one instance can never
// release a lock that another instance has since acquired.
var opsMetricsCollectorReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0
`)
// tryAcquireLeaderLock ensures only one instance collects metrics per cycle.
// It returns (release, true) when this instance may proceed; release may be
// nil when no coordination was needed. When Redis errors, it falls back to a
// DB advisory lock instead of letting every instance hit the database.
func (c *OpsMetricsCollector) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
	if c == nil || c.redisClient == nil {
		// Single-instance mode: no Redis, no coordination required.
		return nil, true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	acquired, err := c.redisClient.SetNX(ctx, opsMetricsCollectorLeaderLockKey, c.instanceID, opsMetricsCollectorLeaderLockTTL).Result()
	if err != nil {
		// Redis is present but unavailable; prefer fail-closed via a DB
		// advisory lock so instances don't stampede the database.
		release, lockOK := tryAcquireDBAdvisoryLock(ctx, c.db, opsMetricsCollectorAdvisoryLockID)
		if !lockOK {
			c.maybeLogSkip()
			return nil, false
		}
		return release, true
	}
	if !acquired {
		c.maybeLogSkip()
		return nil, false
	}
	release := func() {
		relCtx, relCancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer relCancel()
		_, _ = opsMetricsCollectorReleaseScript.Run(relCtx, c.redisClient, []string{opsMetricsCollectorLeaderLockKey}, c.instanceID).Result()
	}
	return release, true
}
// maybeLogSkip logs the "another instance holds the lock" message at most
// once per minute so normal multi-instance operation stays quiet in the logs.
func (c *OpsMetricsCollector) maybeLogSkip() {
	c.skipLogMu.Lock()
	defer c.skipLogMu.Unlock()
	now := time.Now()
	if !c.skipLogAt.IsZero() && now.Sub(c.skipLogAt) < time.Minute {
		return
	}
	c.skipLogAt = now
	log.Printf("[OpsMetricsCollector] leader lock held by another instance; skipping")
}
// floatToIntPtr converts a nullable SQL float into *int by rounding to the
// nearest integer (half away from zero); NULL maps to nil.
func floatToIntPtr(v sql.NullFloat64) *int {
	if v.Valid {
		n := int(math.Round(v.Float64))
		return &n
	}
	return nil
}
// roundTo1DP rounds v to one decimal place (halves round away from zero).
func roundTo1DP(v float64) float64 {
	scaled := math.Round(v * 10)
	return scaled / 10
}
// truncateString returns s limited to at most max bytes, backing up so that a
// multi-byte UTF-8 rune is never split at the cut point.
//
// Bug fix: the previous implementation stripped trailing bytes while the
// whole prefix failed utf8.ValidString. If the string contained an invalid
// byte anywhere (error bodies can hold arbitrary data), that loop erased the
// entire prefix and returned "". We now strip at most a partial trailing rune
// (utf8.UTFMax-1 bytes) and leave earlier invalid bytes alone.
func truncateString(s string, max int) string {
	if max <= 0 {
		return ""
	}
	if len(s) <= max {
		return s
	}
	cut := s[:max]
	for i := 0; i < utf8.UTFMax-1 && len(cut) > 0; i++ {
		r, size := utf8.DecodeLastRuneInString(cut)
		// size > 1 means a genuinely encoded U+FFFD, which is valid content;
		// only size-1 RuneError indicates a dangling partial/invalid byte.
		if r != utf8.RuneError || size > 1 {
			break
		}
		cut = cut[:len(cut)-1]
	}
	return cut
}
// boolPtr returns a pointer to a fresh copy of v.
func boolPtr(v bool) *bool {
	return &v
}

// intPtr returns a pointer to a fresh copy of v.
func intPtr(v int) *int {
	return &v
}

// float64Ptr returns a pointer to a fresh copy of v.
func float64Ptr(v float64) *float64 {
	return &v
}

View File

@@ -0,0 +1,169 @@
package service
import "time"
// OpsErrorLog is a single ops error-log entry in the shape used by list
// endpoints; heavyweight columns live on OpsErrorLogDetail.
type OpsErrorLog struct {
	ID        int64     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
	// Standardized classification
	// - phase: request|auth|routing|upstream|network|internal
	// - owner: client|provider|platform
	// - source: client_request|upstream_http|gateway
	Phase              string     `json:"phase"`
	Type               string     `json:"type"`
	Owner              string     `json:"error_owner"`
	Source             string     `json:"error_source"`
	Severity           string     `json:"severity"`
	StatusCode         int        `json:"status_code"`
	Platform           string     `json:"platform"`
	Model              string     `json:"model"`
	IsRetryable        bool       `json:"is_retryable"`
	RetryCount         int        `json:"retry_count"`
	Resolved           bool       `json:"resolved"`
	ResolvedAt         *time.Time `json:"resolved_at"`
	ResolvedByUserID   *int64     `json:"resolved_by_user_id"`
	ResolvedByUserName string     `json:"resolved_by_user_name"`
	ResolvedRetryID    *int64     `json:"resolved_retry_id"`
	// ResolvedStatusRaw is internal-only and never serialized.
	ResolvedStatusRaw string  `json:"-"`
	ClientRequestID   string  `json:"client_request_id"`
	RequestID         string  `json:"request_id"`
	Message           string  `json:"message"`
	UserID            *int64  `json:"user_id"`
	UserEmail         string  `json:"user_email"`
	APIKeyID          *int64  `json:"api_key_id"`
	AccountID         *int64  `json:"account_id"`
	AccountName       string  `json:"account_name"`
	GroupID           *int64  `json:"group_id"`
	GroupName         string  `json:"group_name"`
	ClientIP          *string `json:"client_ip"`
	RequestPath       string  `json:"request_path"`
	Stream            bool    `json:"stream"`
}
// OpsErrorLogDetail extends OpsErrorLog with the heavyweight columns
// (bodies, headers, timings) that are only loaded for single-error views.
type OpsErrorLogDetail struct {
	OpsErrorLog
	ErrorBody string `json:"error_body"`
	UserAgent string `json:"user_agent"`
	// Upstream context (optional)
	UpstreamStatusCode   *int   `json:"upstream_status_code,omitempty"`
	UpstreamErrorMessage string `json:"upstream_error_message,omitempty"`
	UpstreamErrorDetail  string `json:"upstream_error_detail,omitempty"`
	UpstreamErrors       string `json:"upstream_errors,omitempty"` // JSON array (string) for display/parsing
	// Timings (optional)
	AuthLatencyMs      *int64 `json:"auth_latency_ms"`
	RoutingLatencyMs   *int64 `json:"routing_latency_ms"`
	UpstreamLatencyMs  *int64 `json:"upstream_latency_ms"`
	ResponseLatencyMs  *int64 `json:"response_latency_ms"`
	TimeToFirstTokenMs *int64 `json:"time_to_first_token_ms"`
	// Retry context
	RequestBody          string `json:"request_body"`
	RequestBodyTruncated bool   `json:"request_body_truncated"`
	RequestBodyBytes     *int   `json:"request_body_bytes"`
	RequestHeaders       string `json:"request_headers,omitempty"`
	// vNext metric semantics
	IsBusinessLimited bool `json:"is_business_limited"`
}
// OpsErrorLogFilter is the query object for listing ops error logs.
// NOTE(review): zero/nil fields presumably mean "no constraint" — confirm
// against the repository implementation.
type OpsErrorLogFilter struct {
	StartTime        *time.Time
	EndTime          *time.Time
	Platform         string
	GroupID          *int64
	AccountID        *int64
	StatusCodes      []int
	StatusCodesOther bool
	Phase            string
	Owner            string
	Source           string
	Resolved         *bool
	Query            string
	UserQuery        string // Search by user email
	// Optional correlation keys for exact matching.
	RequestID       string
	ClientRequestID string
	// View controls error categorization for list endpoints.
	// - errors: show actionable errors (exclude business-limited / 429 / 529)
	// - excluded: only show excluded errors
	// - all: show everything
	View     string
	Page     int
	PageSize int
}
// OpsErrorLogList is one page of error-log results plus pagination echo.
type OpsErrorLogList struct {
	Errors   []*OpsErrorLog `json:"errors"`
	Total    int            `json:"total"`
	Page     int            `json:"page"`
	PageSize int            `json:"page_size"`
}
// OpsRetryAttempt records one manual retry of a logged error, including its
// lifecycle timestamps and (best-effort) persisted execution results.
type OpsRetryAttempt struct {
	ID                int64      `json:"id"`
	CreatedAt         time.Time  `json:"created_at"`
	RequestedByUserID int64      `json:"requested_by_user_id"`
	SourceErrorID     int64      `json:"source_error_id"`
	Mode              string     `json:"mode"`
	PinnedAccountID   *int64     `json:"pinned_account_id"`
	PinnedAccountName string     `json:"pinned_account_name"`
	Status            string     `json:"status"`
	StartedAt         *time.Time `json:"started_at"`
	FinishedAt        *time.Time `json:"finished_at"`
	DurationMs        *int64     `json:"duration_ms"`
	// Persisted execution results (best-effort)
	Success           *bool   `json:"success"`
	HTTPStatusCode    *int    `json:"http_status_code"`
	UpstreamRequestID *string `json:"upstream_request_id"`
	UsedAccountID     *int64  `json:"used_account_id"`
	UsedAccountName   string  `json:"used_account_name"`
	ResponsePreview   *string `json:"response_preview"`
	ResponseTruncated *bool   `json:"response_truncated"`
	// Optional correlation
	ResultRequestID *string `json:"result_request_id"`
	ResultErrorID   *int64  `json:"result_error_id"`
	ErrorMessage    *string `json:"error_message"`
}
// OpsRetryResult is the synchronous response returned to the caller who
// triggered a retry of a logged error.
type OpsRetryResult struct {
	AttemptID         int64     `json:"attempt_id"`
	Mode              string    `json:"mode"`
	Status            string    `json:"status"`
	PinnedAccountID   *int64    `json:"pinned_account_id"`
	UsedAccountID     *int64    `json:"used_account_id"`
	HTTPStatusCode    int       `json:"http_status_code"`
	UpstreamRequestID string    `json:"upstream_request_id"`
	ResponsePreview   string    `json:"response_preview"`
	ResponseTruncated bool      `json:"response_truncated"`
	ErrorMessage      string    `json:"error_message"`
	StartedAt         time.Time `json:"started_at"`
	FinishedAt        time.Time `json:"finished_at"`
	DurationMs        int64     `json:"duration_ms"`
}

View File

@@ -0,0 +1,263 @@
package service
import (
"context"
"time"
)
// OpsRepository is the persistence interface behind the ops/observability
// service: error logs, retries, dashboard aggregates, system metrics, job
// heartbeats, alerting, and pre-aggregation bookkeeping.
type OpsRepository interface {
	InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
	ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
	GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
	ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
	InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
	UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
	GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
	ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
	UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
	// Lightweight window stats (for realtime WS / quick sampling).
	GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
	// Lightweight realtime traffic summary (for the Ops dashboard header card).
	GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error)
	GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error)
	GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error)
	GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error)
	GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error)
	GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error)
	InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error
	GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error)
	UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error
	ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error)
	// Alerts (rules + events)
	ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error)
	CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
	UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
	DeleteAlertRule(ctx context.Context, id int64) error
	ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
	GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
	GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
	GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
	CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
	UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
	UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
	// Alert silences
	CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
	IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
	// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
	UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
	UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
	GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error)
	GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error)
}
// OpsInsertErrorLogInput carries one error occurrence to be persisted into
// ops_error_logs. The JSON payload fields are expected to be sanitized and
// serialized by OpsService before insertion (see UpstreamErrors/JSON below).
type OpsInsertErrorLogInput struct {
	RequestID         string
	ClientRequestID   string
	UserID            *int64
	APIKeyID          *int64
	AccountID         *int64
	GroupID           *int64
	ClientIP          *string
	Platform          string
	Model             string
	RequestPath       string
	Stream            bool
	UserAgent         string
	ErrorPhase        string
	ErrorType         string
	Severity          string
	StatusCode        int
	IsBusinessLimited bool
	IsCountTokens     bool // whether this was a count_tokens request
	ErrorMessage      string
	ErrorBody         string
	ErrorSource       string
	ErrorOwner        string
	UpstreamStatusCode   *int
	UpstreamErrorMessage *string
	UpstreamErrorDetail  *string
	// UpstreamErrors captures all upstream error attempts observed during handling this request.
	// It is populated during request processing (gin context) and sanitized+serialized by OpsService.
	UpstreamErrors []*OpsUpstreamErrorEvent
	// UpstreamErrorsJSON is the sanitized JSON string stored into ops_error_logs.upstream_errors.
	// It is set by OpsService.RecordError before persisting.
	UpstreamErrorsJSON   *string
	TimeToFirstTokenMs   *int64
	RequestBodyJSON      *string // sanitized json string (not raw bytes)
	RequestBodyTruncated bool
	RequestBodyBytes     *int
	RequestHeadersJSON   *string // optional json string
	IsRetryable          bool
	RetryCount           int
	CreatedAt            time.Time
}
// OpsInsertRetryAttemptInput creates a new retry-attempt row for a logged
// error before the retry is executed.
type OpsInsertRetryAttemptInput struct {
	RequestedByUserID int64
	SourceErrorID     int64
	Mode              string
	PinnedAccountID   *int64
	// running|queued etc.
	Status    string
	StartedAt time.Time
}
// OpsUpdateRetryAttemptInput finalizes a retry-attempt row once execution
// completes; the result fields are best-effort.
type OpsUpdateRetryAttemptInput struct {
	ID int64
	// succeeded|failed
	Status     string
	FinishedAt time.Time
	DurationMs int64
	// Persisted execution results (best-effort)
	Success           *bool
	HTTPStatusCode    *int
	UpstreamRequestID *string
	UsedAccountID     *int64
	ResponsePreview   *string
	ResponseTruncated *bool
	// Optional correlation (legacy fields kept)
	ResultRequestID *string
	ResultErrorID   *int64
	ErrorMessage    *string
}
// OpsInsertSystemMetricsInput is one per-window row for the system metrics
// table. Pointer fields are optional: nil means the value was not sampled in
// this window.
type OpsInsertSystemMetricsInput struct {
	CreatedAt                    time.Time
	WindowMinutes                int
	Platform                     *string
	GroupID                      *int64
	SuccessCount                 int64
	ErrorCountTotal              int64
	BusinessLimitedCount         int64
	ErrorCountSLA                int64
	UpstreamErrorCountExcl429529 int64
	Upstream429Count             int64
	Upstream529Count             int64
	TokenConsumed                int64
	QPS                          *float64
	TPS                          *float64
	DurationP50Ms                *int
	DurationP90Ms                *int
	DurationP95Ms                *int
	DurationP99Ms                *int
	DurationAvgMs                *float64
	DurationMaxMs                *int
	TTFTP50Ms                    *int
	TTFTP90Ms                    *int
	TTFTP95Ms                    *int
	TTFTP99Ms                    *int
	TTFTAvgMs                    *float64
	TTFTMaxMs                    *int
	CPUUsagePercent              *float64
	MemoryUsedMB                 *int64
	MemoryTotalMB                *int64
	MemoryUsagePercent           *float64
	DBOK                         *bool
	RedisOK                      *bool
	RedisConnTotal               *int
	RedisConnIdle                *int
	DBConnActive                 *int
	DBConnIdle                   *int
	DBConnWaiting                *int
	GoroutineCount               *int
	ConcurrencyQueueDepth        *int
}
// OpsSystemMetricsSnapshot is the most recent system-metrics row as served to
// the ops UI; nil fields were not sampled for that window.
type OpsSystemMetricsSnapshot struct {
	ID                 int64     `json:"id"`
	CreatedAt          time.Time `json:"created_at"`
	WindowMinutes      int       `json:"window_minutes"`
	CPUUsagePercent    *float64  `json:"cpu_usage_percent"`
	MemoryUsedMB       *int64    `json:"memory_used_mb"`
	MemoryTotalMB      *int64    `json:"memory_total_mb"`
	MemoryUsagePercent *float64  `json:"memory_usage_percent"`
	DBOK               *bool     `json:"db_ok"`
	RedisOK            *bool     `json:"redis_ok"`
	// Config-derived limits (best-effort). These are not historical metrics; they help UI render "current vs max".
	DBMaxOpenConns        *int `json:"db_max_open_conns"`
	RedisPoolSize         *int `json:"redis_pool_size"`
	RedisConnTotal        *int `json:"redis_conn_total"`
	RedisConnIdle         *int `json:"redis_conn_idle"`
	DBConnActive          *int `json:"db_conn_active"`
	DBConnIdle            *int `json:"db_conn_idle"`
	DBConnWaiting         *int `json:"db_conn_waiting"`
	GoroutineCount        *int `json:"goroutine_count"`
	ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
}
// OpsUpsertJobHeartbeatInput updates a background job's heartbeat row;
// nil fields leave the corresponding column untouched by the upsert.
// NOTE(review): the "nil leaves column untouched" semantics should be
// confirmed against the repository implementation.
type OpsUpsertJobHeartbeatInput struct {
	JobName        string
	LastRunAt      *time.Time
	LastSuccessAt  *time.Time
	LastErrorAt    *time.Time
	LastError      *string
	LastDurationMs *int64
	// LastResult is an optional human-readable summary of the last successful run.
	LastResult *string
}
// OpsJobHeartbeat is a background job's stored heartbeat as rendered in the
// ops UI.
type OpsJobHeartbeat struct {
	JobName        string     `json:"job_name"`
	LastRunAt      *time.Time `json:"last_run_at"`
	LastSuccessAt  *time.Time `json:"last_success_at"`
	LastErrorAt    *time.Time `json:"last_error_at"`
	LastError      *string    `json:"last_error"`
	LastDurationMs *int64     `json:"last_duration_ms"`
	LastResult     *string    `json:"last_result"`
	UpdatedAt      time.Time  `json:"updated_at"`
}
// OpsWindowStats is a lightweight success/error/token summary for one time
// window, used by realtime sampling endpoints.
type OpsWindowStats struct {
	StartTime       time.Time `json:"start_time"`
	EndTime         time.Time `json:"end_time"`
	SuccessCount    int64     `json:"success_count"`
	ErrorCountTotal int64     `json:"error_count_total"`
	TokenConsumed   int64     `json:"token_consumed"`
}

View File

@@ -0,0 +1,40 @@
package service
import (
"errors"
"strings"
)
// OpsQueryMode selects which storage tier ops dashboard queries read from.
type OpsQueryMode string

const (
	OpsQueryModeAuto   OpsQueryMode = "auto"
	OpsQueryModeRaw    OpsQueryMode = "raw"
	OpsQueryModePreagg OpsQueryMode = "preagg"
)

// ErrOpsPreaggregatedNotPopulated indicates that raw logs exist for a window, but the
// pre-aggregation tables are not populated yet. This is primarily used to implement
// the forced `preagg` mode UX.
var ErrOpsPreaggregatedNotPopulated = errors.New("ops pre-aggregated tables not populated")

// ParseOpsQueryMode normalizes raw (case-insensitively, ignoring surrounding
// whitespace) into a known mode, defaulting to OpsQueryModeAuto.
func ParseOpsQueryMode(raw string) OpsQueryMode {
	normalized := OpsQueryMode(strings.ToLower(strings.TrimSpace(raw)))
	switch normalized {
	case OpsQueryModeRaw, OpsQueryModePreagg:
		return normalized
	}
	return OpsQueryModeAuto
}

// IsValid reports whether m is one of the recognized query modes.
func (m OpsQueryMode) IsValid() bool {
	return m == OpsQueryModeAuto || m == OpsQueryModeRaw || m == OpsQueryModePreagg
}

View File

@@ -0,0 +1,36 @@
package service
import (
"context"
"errors"
"strings"
)
// IsRealtimeMonitoringEnabled returns true when realtime ops features are enabled.
//
// This is a soft switch controlled by the DB setting `ops_realtime_monitoring_enabled`,
// and it is also gated by the hard switch/soft switch of overall ops monitoring.
func (s *OpsService) IsRealtimeMonitoringEnabled(ctx context.Context) bool {
	if !s.IsMonitoringEnabled(ctx) {
		return false
	}
	if s.settingRepo == nil {
		return true
	}
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRealtimeMonitoringEnabled)
	if err != nil {
		// Missing key => default enabled; transient errors fail-open too.
		if errors.Is(err, ErrSettingNotFound) {
			return true
		}
		return true
	}
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "false", "0", "off", "disabled":
		return false
	}
	return true
}

View File

@@ -0,0 +1,81 @@
package service
import "time"
// PlatformConcurrencyInfo aggregates concurrency usage by platform.
// NOTE(review): LoadPercentage is presumably CurrentInUse relative to
// MaxCapacity expressed as a percentage — confirm with the producing service.
type PlatformConcurrencyInfo struct {
	Platform       string  `json:"platform"`
	CurrentInUse   int64   `json:"current_in_use"`
	MaxCapacity    int64   `json:"max_capacity"`
	LoadPercentage float64 `json:"load_percentage"`
	WaitingInQueue int64   `json:"waiting_in_queue"`
}
// GroupConcurrencyInfo aggregates concurrency usage by group.
//
// Note: one account can belong to multiple groups; group totals are therefore not additive across groups.
type GroupConcurrencyInfo struct {
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Platform string `json:"platform"`
CurrentInUse int64 `json:"current_in_use"`
MaxCapacity int64 `json:"max_capacity"`
LoadPercentage float64 `json:"load_percentage"`
WaitingInQueue int64 `json:"waiting_in_queue"`
}
// AccountConcurrencyInfo represents real-time concurrency usage for a single account.
type AccountConcurrencyInfo struct {
AccountID int64 `json:"account_id"`
AccountName string `json:"account_name"`
Platform string `json:"platform"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
CurrentInUse int64 `json:"current_in_use"`
MaxCapacity int64 `json:"max_capacity"`
LoadPercentage float64 `json:"load_percentage"`
WaitingInQueue int64 `json:"waiting_in_queue"`
}
// PlatformAvailability aggregates account availability by platform.
type PlatformAvailability struct {
	Platform       string `json:"platform"`
	TotalAccounts  int64  `json:"total_accounts"`
	AvailableCount int64  `json:"available_count"`
	RateLimitCount int64  `json:"rate_limit_count"`
	ErrorCount     int64  `json:"error_count"`
}

// GroupAvailability aggregates account availability by group.
type GroupAvailability struct {
	GroupID        int64  `json:"group_id"`
	GroupName      string `json:"group_name"`
	Platform       string `json:"platform"`
	TotalAccounts  int64  `json:"total_accounts"`
	AvailableCount int64  `json:"available_count"`
	RateLimitCount int64  `json:"rate_limit_count"`
	ErrorCount     int64  `json:"error_count"`
}

// AccountAvailability represents current availability for a single account.
// The *RemainingSec fields are derived countdowns that pair with the
// corresponding *At timestamps; nil means not currently limited/overloaded.
type AccountAvailability struct {
	AccountID             int64      `json:"account_id"`
	AccountName           string     `json:"account_name"`
	Platform              string     `json:"platform"`
	GroupID               int64      `json:"group_id"`
	GroupName             string     `json:"group_name"`
	Status                string     `json:"status"`
	IsAvailable           bool       `json:"is_available"`
	IsRateLimited         bool       `json:"is_rate_limited"`
	IsOverloaded          bool       `json:"is_overloaded"`
	HasError              bool       `json:"has_error"`
	RateLimitResetAt      *time.Time `json:"rate_limit_reset_at"`
	RateLimitRemainingSec *int64     `json:"rate_limit_remaining_sec"`
	OverloadUntil         *time.Time `json:"overload_until"`
	OverloadRemainingSec  *int64     `json:"overload_remaining_sec"`
	ErrorMessage          string     `json:"error_message"`
	TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
}

View File

@@ -0,0 +1,36 @@
package service
import (
"context"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the provided window.
// This is used by the Ops dashboard "Realtime Traffic" card and is intentionally lightweight.
//
// The filter must carry both endpoints and the window may not exceed one hour.
// Note: the caller's filter has its QueryMode overwritten to raw before the
// repository call (minute-granularity peaks require raw logs).
func (s *OpsService) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	switch {
	case filter == nil:
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	case filter.StartTime.IsZero() || filter.EndTime.IsZero():
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	case filter.StartTime.After(filter.EndTime):
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	case filter.EndTime.Sub(filter.StartTime) > time.Hour:
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_TOO_LARGE", "invalid time range: max window is 1 hour")
	}
	// Realtime traffic summary always uses raw logs (minute granularity peaks).
	filter.QueryMode = OpsQueryModeRaw
	return s.opsRepo.GetRealtimeTrafficSummary(ctx, filter)
}

View File

@@ -0,0 +1,19 @@
package service
import "time"
// OpsRealtimeTrafficSummary is a lightweight summary used by the Ops dashboard "Realtime Traffic" card.
// It reports QPS/TPS current/peak/avg for the requested time window.
type OpsRealtimeTrafficSummary struct {
	// Window is a normalized label (e.g. "1min", "5min", "30min", "1h").
	Window    string    `json:"window"`
	StartTime time.Time `json:"start_time"`
	EndTime   time.Time `json:"end_time"`
	// Platform and GroupID echo the filter the summary was computed for.
	Platform string         `json:"platform"`
	GroupID  *int64         `json:"group_id"`
	QPS      OpsRateSummary `json:"qps"`
	TPS      OpsRateSummary `json:"tps"`
}

View File

@@ -0,0 +1,151 @@
package service
import (
"context"
"time"
)
// OpsRequestKind distinguishes the two row sources of the request drilldown:
// successful requests (usage_logs) and failed ones (ops_error_logs).
type OpsRequestKind string

const (
	OpsRequestKindSuccess OpsRequestKind = "success"
	OpsRequestKindError   OpsRequestKind = "error"
)

// OpsRequestDetail is a request-level view across success (usage_logs) and error (ops_error_logs).
// It powers "request drilldown" UIs without exposing full request bodies for successful requests.
type OpsRequestDetail struct {
	Kind       OpsRequestKind `json:"kind"`
	CreatedAt  time.Time      `json:"created_at"`
	RequestID  string         `json:"request_id"`
	Platform   string         `json:"platform,omitempty"`
	Model      string         `json:"model,omitempty"`
	DurationMs *int           `json:"duration_ms,omitempty"`
	StatusCode *int           `json:"status_code,omitempty"`
	// When Kind == "error", ErrorID links to /admin/ops/errors/:id.
	ErrorID  *int64 `json:"error_id,omitempty"`
	Phase    string `json:"phase,omitempty"`
	Severity string `json:"severity,omitempty"`
	Message  string `json:"message,omitempty"`
	// Attribution; nil when the dimension is unknown for this row.
	UserID    *int64 `json:"user_id,omitempty"`
	APIKeyID  *int64 `json:"api_key_id,omitempty"`
	AccountID *int64 `json:"account_id,omitempty"`
	GroupID   *int64 `json:"group_id,omitempty"`
	Stream    bool   `json:"stream"`
}
type OpsRequestDetailFilter struct {
StartTime *time.Time
EndTime *time.Time
// kind: success|error|all
Kind string
Platform string
GroupID *int64
UserID *int64
APIKeyID *int64
AccountID *int64
Model string
RequestID string
Query string
MinDurationMs *int
MaxDurationMs *int
// Sort: created_at_desc (default) or duration_desc.
Sort string
Page int
PageSize int
}
func (f *OpsRequestDetailFilter) Normalize() (page, pageSize int, startTime, endTime time.Time) {
page = 1
pageSize = 50
endTime = time.Now()
startTime = endTime.Add(-1 * time.Hour)
if f == nil {
return page, pageSize, startTime, endTime
}
if f.Page > 0 {
page = f.Page
}
if f.PageSize > 0 {
pageSize = f.PageSize
}
if pageSize > 100 {
pageSize = 100
}
if f.EndTime != nil {
endTime = *f.EndTime
}
if f.StartTime != nil {
startTime = *f.StartTime
} else if f.EndTime != nil {
startTime = endTime.Add(-1 * time.Hour)
}
if startTime.After(endTime) {
startTime, endTime = endTime, startTime
}
return page, pageSize, startTime, endTime
}
// OpsRequestDetailList is one page of drilldown rows plus paging metadata.
type OpsRequestDetailList struct {
	Items    []*OpsRequestDetail `json:"items"`
	Total    int64               `json:"total"` // total matching rows across all pages
	Page     int                 `json:"page"`
	PageSize int                 `json:"page_size"`
}
// ListRequestDetails returns a page of request-level drilldown rows.
//
// When the ops repository is unavailable an empty first page is returned
// instead of an error so the dashboard degrades gracefully. The caller's
// filter is never mutated: normalized paging/time bounds are applied to a
// shallow copy before querying.
func (s *OpsService) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) (*OpsRequestDetailList, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return &OpsRequestDetailList{Items: []*OpsRequestDetail{}, Page: 1, PageSize: 50}, nil
	}
	page, pageSize, startTime, endTime := filter.Normalize()
	normalized := &OpsRequestDetailFilter{}
	if filter != nil {
		*normalized = *filter
	}
	normalized.Page = page
	normalized.PageSize = pageSize
	normalized.StartTime = &startTime
	normalized.EndTime = &endTime
	items, total, err := s.opsRepo.ListRequestDetails(ctx, normalized)
	if err != nil {
		return nil, err
	}
	if items == nil {
		// Guarantee a non-nil slice so JSON encodes [] rather than null.
		items = []*OpsRequestDetail{}
	}
	return &OpsRequestDetailList{
		Items:    items,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
	}, nil
}

View File

@@ -0,0 +1,720 @@
package service
import (
	"bytes"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"slices"
	"strings"
	"time"

	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/gin-gonic/gin"
	"github.com/lib/pq"
)
const (
	// OpsRetryModeClient re-runs the request through normal account selection.
	OpsRetryModeClient = "client"
	// OpsRetryModeUpstream re-runs the request pinned to a specific account.
	OpsRetryModeUpstream = "upstream"
)

// Lifecycle states persisted on a retry attempt row.
const (
	opsRetryStatusRunning   = "running"
	opsRetryStatusSucceeded = "succeeded"
	opsRetryStatusFailed    = "failed"
)

const (
	// opsRetryTimeout bounds one retry execution end-to-end.
	opsRetryTimeout = 60 * time.Second
	// opsRetryCaptureBytesLimit caps how much upstream response body is buffered.
	opsRetryCaptureBytesLimit = 64 * 1024
	// opsRetryResponsePreviewMax caps the preview persisted with the attempt.
	opsRetryResponsePreviewMax = 8 * 1024
	// opsRetryMinIntervalPerError throttles back-to-back retries of one error.
	opsRetryMinIntervalPerError = 10 * time.Second
	// opsRetryMaxAccountSwitches bounds client-mode failover across accounts.
	opsRetryMaxAccountSwitches = 3
)

// opsRetryRequestHeaderAllowlist lists the only request headers replayed on a
// retry; auth credentials are deliberately never replayed.
var opsRetryRequestHeaderAllowlist = map[string]bool{
	"anthropic-beta":    true,
	"anthropic-version": true,
}

// opsRetryRequestType classifies which upstream API shape the original request used.
type opsRetryRequestType string

const (
	opsRetryTypeMessages  opsRetryRequestType = "messages"
	opsRetryTypeOpenAI    opsRetryRequestType = "openai_responses"
	opsRetryTypeGeminiV1B opsRetryRequestType = "gemini_v1beta"
)
// limitedResponseWriter is an http.ResponseWriter that buffers at most `limit`
// bytes of the response body while pretending every write fully succeeded, so
// upstream forwarding code never sees a short write. The status code passed to
// WriteHeader is intentionally not recorded here (the surrounding gin writer
// tracks it).
type limitedResponseWriter struct {
	header       http.Header
	wroteHeader  bool
	limit        int
	totalWritten int64
	buf          bytes.Buffer
}

// newLimitedResponseWriter returns a capture writer holding up to limit bytes
// (a non-positive limit is coerced to 1).
func newLimitedResponseWriter(limit int) *limitedResponseWriter {
	w := &limitedResponseWriter{header: make(http.Header), limit: limit}
	if w.limit <= 0 {
		w.limit = 1
	}
	return w
}

// Header returns the (captured, never transmitted) header map.
func (w *limitedResponseWriter) Header() http.Header {
	return w.header
}

// WriteHeader marks the header as written; the code itself is discarded.
func (w *limitedResponseWriter) WriteHeader(statusCode int) {
	if !w.wroteHeader {
		w.wroteHeader = true
	}
}

// Write counts every byte, buffers up to the limit, and always reports success.
func (w *limitedResponseWriter) Write(p []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	w.totalWritten += int64(len(p))
	if room := w.limit - w.buf.Len(); room > 0 {
		chunk := p
		if len(chunk) > room {
			chunk = chunk[:room]
		}
		_, _ = w.buf.Write(chunk)
	}
	// Pretend we wrote everything to avoid upstream/client code treating it as an error.
	return len(p), nil
}

// Flush satisfies http.Flusher; there is nothing to flush.
func (w *limitedResponseWriter) Flush() {}

// bodyBytes returns the captured (possibly truncated) body prefix.
func (w *limitedResponseWriter) bodyBytes() []byte {
	return w.buf.Bytes()
}

// truncated reports whether more bytes were written than the capture limit.
func (w *limitedResponseWriter) truncated() bool {
	return w.totalWritten > int64(w.limit)
}
const (
	// OpsRetryModeUpstreamEvent labels retries of a single upstream attempt
	// stored in ops_error_logs.upstream_errors; they execute as pinned
	// upstream retries but are persisted under this distinct mode.
	OpsRetryModeUpstreamEvent = "upstream_event"
)
// RetryError re-executes the request captured in an ops error log.
//
// mode is "client" (normal account selection) or "upstream" (pinned). For
// upstream retries the pinned account is the explicit pinnedAccountID when
// positive, otherwise the account recorded on the error log; having neither
// is a bad request.
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	mode = strings.ToLower(strings.TrimSpace(mode))
	if mode != OpsRetryModeClient && mode != OpsRetryModeUpstream {
		return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
	}
	errorLog, err := s.GetErrorLogByID(ctx, errorID)
	if err != nil {
		return nil, err
	}
	if errorLog == nil {
		return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
	}
	if strings.TrimSpace(errorLog.RequestBody) == "" {
		return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
	}
	var pinned *int64
	if mode == OpsRetryModeUpstream {
		switch {
		case pinnedAccountID != nil && *pinnedAccountID > 0:
			pinned = pinnedAccountID
		case errorLog.AccountID != nil && *errorLog.AccountID > 0:
			pinned = errorLog.AccountID
		default:
			return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
		}
	}
	return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
}
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
// idx is 0-based. It always pins the original event account_id.
//
// The stored upstream request body of that attempt replaces the error log's
// client body (on a shallow copy, so the original detail is not mutated), and
// the attempt is persisted as mode "upstream_event" while executing as a
// pinned upstream retry.
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if idx < 0 {
		return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
	}
	errorLog, err := s.GetErrorLogByID(ctx, errorID)
	if err != nil {
		return nil, err
	}
	if errorLog == nil {
		return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
	}
	// Upstream attempts are stored serialized on the error log; parse them out.
	events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
	if err != nil {
		return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
	}
	if idx >= len(events) {
		return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
	}
	ev := events[idx]
	if ev == nil {
		return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
	}
	if ev.AccountID <= 0 {
		return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
	}
	upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
	if upstreamBody == "" {
		return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
	}
	// Shallow copy: replay exactly what was sent upstream on this attempt.
	override := *errorLog
	override.RequestBody = upstreamBody
	pinned := ev.AccountID
	// Persist as upstream_event, execute as upstream pinned retry.
	return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
}
// retryWithErrorLog records a retry attempt, executes it, and persists the outcome.
//
// mode is what gets stored on the attempt (client/upstream/upstream_event);
// execMode is how the retry actually runs (client or upstream). For upstream
// execution a pinned account is resolved from pinnedAccountID or the error
// log's account. On success the source error is additionally marked resolved.
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
	// Throttle: reject while another retry runs or one finished very recently.
	latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
	}
	if latest != nil {
		if strings.EqualFold(latest.Status, opsRetryStatusRunning) || strings.EqualFold(latest.Status, "queued") {
			return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
		}
		// Use the most precise timestamp available for the previous attempt.
		lastAttemptAt := latest.CreatedAt
		if latest.FinishedAt != nil && !latest.FinishedAt.IsZero() {
			lastAttemptAt = *latest.FinishedAt
		} else if latest.StartedAt != nil && !latest.StartedAt.IsZero() {
			lastAttemptAt = *latest.StartedAt
		}
		if time.Since(lastAttemptAt) < opsRetryMinIntervalPerError {
			return nil, infraerrors.Conflict("OPS_RETRY_TOO_FREQUENT", "Please wait before retrying this error again")
		}
	}
	if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
		return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
	}
	// Resolve the account to pin for upstream execution.
	var pinned *int64
	if execMode == OpsRetryModeUpstream {
		if pinnedAccountID != nil && *pinnedAccountID > 0 {
			pinned = pinnedAccountID
		} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
			pinned = errorLog.AccountID
		} else {
			return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
		}
	}
	startedAt := time.Now()
	// Insert the attempt as "running" before executing. A 23505 unique
	// violation here is treated as a concurrent retry — presumably backed by
	// a unique/partial index on running attempts; confirm against the schema.
	attemptID, err := s.opsRepo.InsertRetryAttempt(ctx, &OpsInsertRetryAttemptInput{
		RequestedByUserID: requestedByUserID,
		SourceErrorID:     errorID,
		Mode:              mode,
		PinnedAccountID:   pinned,
		Status:            opsRetryStatusRunning,
		StartedAt:         startedAt,
	})
	if err != nil {
		var pqErr *pq.Error
		if errors.As(err, &pqErr) && string(pqErr.Code) == "23505" {
			return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
		}
		return nil, infraerrors.InternalServer("OPS_RETRY_CREATE_ATTEMPT_FAILED", "Failed to create retry attempt").WithCause(err)
	}
	// Pessimistic default: failed until execution proves otherwise.
	result := &OpsRetryResult{
		AttemptID:         attemptID,
		Mode:              mode,
		Status:            opsRetryStatusFailed,
		PinnedAccountID:   pinned,
		HTTPStatusCode:    0,
		UpstreamRequestID: "",
		ResponsePreview:   "",
		ResponseTruncated: false,
		ErrorMessage:      "",
		StartedAt:         startedAt,
	}
	execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
	defer cancel()
	execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
	finishedAt := time.Now()
	result.FinishedAt = finishedAt
	result.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
	if execRes != nil {
		result.Status = execRes.status
		result.UsedAccountID = execRes.usedAccountID
		result.HTTPStatusCode = execRes.httpStatusCode
		result.UpstreamRequestID = execRes.upstreamRequestID
		result.ResponsePreview = execRes.responsePreview
		result.ResponseTruncated = execRes.responseTruncated
		result.ErrorMessage = execRes.errorMessage
	}
	// Persist the outcome on a fresh short-lived context so a cancelled
	// request context cannot lose the bookkeeping write.
	updateCtx, updateCancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer updateCancel()
	var updateErrMsg *string
	if strings.TrimSpace(result.ErrorMessage) != "" {
		msg := result.ErrorMessage
		updateErrMsg = &msg
	}
	// Keep legacy result_request_id empty; use upstream_request_id instead.
	var resultRequestID *string
	finalStatus := result.Status
	if strings.TrimSpace(finalStatus) == "" {
		finalStatus = opsRetryStatusFailed
	}
	success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
	httpStatus := result.HTTPStatusCode
	upstreamReqID := result.UpstreamRequestID
	usedAccountID := result.UsedAccountID
	preview := result.ResponsePreview
	truncated := result.ResponseTruncated
	if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
		ID:                attemptID,
		Status:            finalStatus,
		FinishedAt:        finishedAt,
		DurationMs:        result.DurationMs,
		Success:           &success,
		HTTPStatusCode:    &httpStatus,
		UpstreamRequestID: &upstreamReqID,
		UsedAccountID:     usedAccountID,
		ResponsePreview:   &preview,
		ResponseTruncated: &truncated,
		ResultRequestID:   resultRequestID,
		ErrorMessage:      updateErrMsg,
	}); err != nil {
		log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
	} else if success {
		// A successful retry marks the originating error as resolved.
		if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
			log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
		}
	}
	return result, nil
}
// opsRetryExecution captures the outcome of one retry execution.
type opsRetryExecution struct {
	status            string // opsRetryStatusSucceeded or opsRetryStatusFailed
	usedAccountID     *int64 // account that actually served the retry, when known
	httpStatusCode    int
	upstreamRequestID string
	responsePreview   string // bounded preview of the upstream response body
	responseTruncated bool
	errorMessage      string
}
// executeRetry prepares the replay body and dispatches to the pinned or
// client execution path. It never returns nil; failures are reported through
// the execution's status and errorMessage.
func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDetail, mode string, pinnedAccountID *int64) *opsRetryExecution {
	if errorLog == nil {
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "missing error log"}
	}
	reqType := detectOpsRetryType(errorLog.RequestPath)
	bodyBytes := []byte(errorLog.RequestBody)
	// Only Anthropic messages bodies get thinking blocks stripped before
	// replay; OpenAI and Gemini bodies are replayed verbatim.
	if reqType == opsRetryTypeMessages {
		bodyBytes = FilterThinkingBlocksForRetry(bodyBytes)
	}
	switch strings.ToLower(strings.TrimSpace(mode)) {
	case OpsRetryModeUpstream:
		if pinnedAccountID == nil || *pinnedAccountID <= 0 {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "pinned_account_id required for upstream retry"}
		}
		return s.executePinnedRetry(ctx, reqType, errorLog, bodyBytes, *pinnedAccountID)
	case OpsRetryModeClient:
		return s.executeClientRetry(ctx, reqType, errorLog, bodyBytes)
	}
	return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "invalid retry mode"}
}
// detectOpsRetryType infers the upstream API shape from the request path:
// "/responses" → OpenAI responses, "/v1beta/" → Gemini native, anything else
// is treated as Anthropic messages.
func detectOpsRetryType(path string) opsRetryRequestType {
	normalized := strings.ToLower(strings.TrimSpace(path))
	if strings.Contains(normalized, "/responses") {
		return opsRetryTypeOpenAI
	}
	if strings.Contains(normalized, "/v1beta/") {
		return opsRetryTypeGeminiV1B
	}
	return opsRetryTypeMessages
}
// executePinnedRetry runs a retry against one specific account, verifying
// schedulability, membership in the original request's group, and the
// account's concurrency limit before forwarding.
func (s *OpsService) executePinnedRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, pinnedAccountID int64) *opsRetryExecution {
	fail := func(msg string) *opsRetryExecution {
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: msg}
	}
	if s.accountRepo == nil {
		return fail("account repository not available")
	}
	account, err := s.accountRepo.GetByID(ctx, pinnedAccountID)
	if err != nil {
		return fail(fmt.Sprintf("account not found: %v", err))
	}
	if account == nil {
		return fail("account not found")
	}
	if !account.IsSchedulable() {
		return fail("account is not schedulable")
	}
	// The retry must stay inside the group the original request used.
	if errorLog.GroupID != nil && *errorLog.GroupID > 0 && !containsInt64(account.GroupIDs, *errorLog.GroupID) {
		return fail("pinned account is not in the same group as the original request")
	}
	if s.concurrencyService != nil {
		acq, acqErr := s.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
		if acqErr != nil {
			return fail(fmt.Sprintf("acquire account slot failed: %v", acqErr))
		}
		if acq == nil || !acq.Acquired {
			return fail("account concurrency limit reached")
		}
		if acq.ReleaseFunc != nil {
			defer acq.ReleaseFunc()
		}
	}
	exec := s.executeWithAccount(ctx, reqType, errorLog, body, account)
	usedID := account.ID
	exec.usedAccountID = &usedID
	if exec.status == "" {
		exec.status = opsRetryStatusFailed
	}
	return exec
}
// executeClientRetry replays the request through normal account selection in
// the original request's group, excluding accounts that fail and switching to
// another one (up to opsRetryMaxAccountSwitches) when the gateway signals a
// failover-worthy error.
func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) *opsRetryExecution {
	groupID := errorLog.GroupID
	if groupID == nil || *groupID <= 0 {
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "group_id missing; cannot reselect account"}
	}
	model, stream, parsedErr := extractRetryModelAndStream(reqType, errorLog, body)
	if parsedErr != nil {
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: parsedErr.Error()}
	}
	// Stream flag is recovered for parity with parsing but unused here.
	_ = stream
	excluded := make(map[int64]struct{})
	switches := 0
	for {
		if switches >= opsRetryMaxAccountSwitches {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "retry failed after exhausting account failovers"}
		}
		selection, selErr := s.selectAccountForRetry(ctx, reqType, groupID, model, excluded)
		if selErr != nil {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
		}
		if selection == nil || selection.Account == nil {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"}
		}
		account := selection.Account
		// No concurrency slot: exclude this account and try the next one.
		if !selection.Acquired || selection.ReleaseFunc == nil {
			excluded[account.ID] = struct{}{}
			switches++
			continue
		}
		// Wrap in a closure so the slot is released when THIS attempt ends,
		// not when the whole loop exits.
		exec := func() *opsRetryExecution {
			defer selection.ReleaseFunc()
			return s.executeWithAccount(ctx, reqType, errorLog, body, account)
		}()
		if exec != nil {
			if exec.status == opsRetryStatusSucceeded {
				usedID := account.ID
				exec.usedAccountID = &usedID
				return exec
			}
			// If the gateway services ask for failover, try another account.
			if s.isFailoverError(exec.errorMessage) {
				excluded[account.ID] = struct{}{}
				switches++
				continue
			}
			// Non-failover failure: report it as-is.
			usedID := account.ID
			exec.usedAccountID = &usedID
			return exec
		}
		excluded[account.ID] = struct{}{}
		switches++
	}
}
// selectAccountForRetry picks a schedulable account for a client-mode retry,
// delegating to the gateway matching the original request's API shape.
func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetryRequestType, groupID *int64, model string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
	if reqType == opsRetryTypeOpenAI {
		if s.openAIGatewayService == nil {
			return nil, fmt.Errorf("openai gateway service not available")
		}
		return s.openAIGatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
	}
	if reqType == opsRetryTypeGeminiV1B || reqType == opsRetryTypeMessages {
		if s.gatewayService == nil {
			return nil, fmt.Errorf("gateway service not available")
		}
		// Retries do not use session limits, hence the empty session key.
		return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "")
	}
	return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}
// extractRetryModelAndStream recovers the model name and stream flag needed
// for account reselection. Messages and OpenAI bodies are parsed; Gemini
// v1beta carries the model in the URL, so the error-log metadata is used.
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
	switch reqType {
	case opsRetryTypeMessages:
		parsed, parseErr := ParseGatewayRequest(body)
		if parseErr != nil {
			return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
		}
		return parsed.Model, parsed.Stream, nil
	case opsRetryTypeOpenAI:
		var payload struct {
			Model  string `json:"model"`
			Stream bool   `json:"stream"`
		}
		if unmarshalErr := json.Unmarshal(body, &payload); unmarshalErr != nil {
			return "", false, fmt.Errorf("failed to parse openai request body: %w", unmarshalErr)
		}
		return strings.TrimSpace(payload.Model), payload.Stream, nil
	case opsRetryTypeGeminiV1B:
		model = strings.TrimSpace(errorLog.Model)
		if model == "" {
			return "", false, fmt.Errorf("missing model for gemini v1beta retry")
		}
		return model, errorLog.Stream, nil
	}
	return "", false, fmt.Errorf("unsupported retry type: %s", reqType)
}
// executeWithAccount forwards the captured body through the gateway matching
// (reqType, account.Platform) into a synthetic gin context, then converts the
// HTTP outcome into an opsRetryExecution. Success requires both a nil forward
// error and a status code below 400.
func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, account *Account) *opsRetryExecution {
	if account == nil {
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "missing account"}
	}
	// c records status/headers via gin's writer; w captures a bounded body preview.
	c, w := newOpsRetryContext(ctx, errorLog)
	var err error
	switch reqType {
	case opsRetryTypeOpenAI:
		if s.openAIGatewayService == nil {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "openai gateway service not available"}
		}
		_, err = s.openAIGatewayService.Forward(ctx, c, account, body)
	case opsRetryTypeGeminiV1B:
		if s.geminiCompatService == nil || s.antigravityGatewayService == nil {
			return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini services not available"}
		}
		modelName := strings.TrimSpace(errorLog.Model)
		// Gemini encodes streaming in the action name rather than the body.
		action := "generateContent"
		if errorLog.Stream {
			action = "streamGenerateContent"
		}
		if account.Platform == PlatformAntigravity {
			_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body)
		} else {
			_, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body)
		}
	case opsRetryTypeMessages:
		// Messages requests route by the serving account's platform.
		switch account.Platform {
		case PlatformAntigravity:
			if s.antigravityGatewayService == nil {
				return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"}
			}
			_, err = s.antigravityGatewayService.Forward(ctx, c, account, body)
		case PlatformGemini:
			if s.geminiCompatService == nil {
				return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"}
			}
			_, err = s.geminiCompatService.Forward(ctx, c, account, body)
		default:
			if s.gatewayService == nil {
				return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
			}
			parsedReq, parseErr := ParseGatewayRequest(body)
			if parseErr != nil {
				return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
			}
			_, err = s.gatewayService.Forward(ctx, c, account, parsedReq)
		}
	default:
		return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "unsupported retry type"}
	}
	statusCode := http.StatusOK
	if c != nil && c.Writer != nil {
		statusCode = c.Writer.Status()
	}
	upstreamReqID := extractUpstreamRequestID(c)
	preview, truncated := extractResponsePreview(w)
	exec := &opsRetryExecution{
		status:            opsRetryStatusFailed,
		httpStatusCode:    statusCode,
		upstreamRequestID: upstreamReqID,
		responsePreview:   preview,
		responseTruncated: truncated,
		errorMessage:      "",
	}
	if err == nil && statusCode < 400 {
		exec.status = opsRetryStatusSucceeded
		return exec
	}
	if err != nil {
		exec.errorMessage = err.Error()
	} else {
		exec.errorMessage = fmt.Sprintf("upstream returned status %d", statusCode)
	}
	return exec
}
// newOpsRetryContext builds a synthetic gin context whose response writer
// captures a bounded prefix of the upstream response.
//
// Only a small allowlist of stored request headers is replayed for fidelity
// (anthropic-beta / anthropic-version); credentials are never replayed.
func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.Context, *limitedResponseWriter) {
	w := newLimitedResponseWriter(opsRetryCaptureBytesLimit)
	c, _ := gin.CreateTestContext(w)

	path := "/"
	if errorLog != nil && strings.TrimSpace(errorLog.RequestPath) != "" {
		path = errorLog.RequestPath
	}
	req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost"+path, bytes.NewReader(nil))
	req.Header.Set("content-type", "application/json")
	if errorLog != nil && strings.TrimSpace(errorLog.UserAgent) != "" {
		req.Header.Set("user-agent", errorLog.UserAgent)
	}
	// Restore a minimal, whitelisted subset of request headers to improve retry fidelity
	// (e.g. anthropic-beta / anthropic-version). Never replay auth credentials.
	if errorLog != nil && strings.TrimSpace(errorLog.RequestHeaders) != "" {
		var stored map[string]string
		if err := json.Unmarshal([]byte(errorLog.RequestHeaders), &stored); err == nil {
			for name, value := range stored {
				name = strings.TrimSpace(name)
				value = strings.TrimSpace(value)
				if name == "" || value == "" {
					continue
				}
				if opsRetryRequestHeaderAllowlist[strings.ToLower(name)] {
					req.Header.Set(name, value)
				}
			}
		}
	}
	c.Request = req
	return c, w
}
// extractUpstreamRequestID returns the upstream request ID recorded on the
// response headers, or "" when unavailable.
//
// http.Header.Get canonicalizes its key (textproto canonical MIME form), so
// one lookup covers every case variant of "X-Request-Id"; the previous loop
// over three spellings performed the identical lookup three times.
func extractUpstreamRequestID(c *gin.Context) string {
	if c == nil || c.Writer == nil {
		return ""
	}
	h := c.Writer.Header()
	if h == nil {
		return ""
	}
	return strings.TrimSpace(h.Get("X-Request-Id"))
}
// extractResponsePreview returns a whitespace-trimmed, size-capped preview of
// the captured response body plus a truncation flag (set either when the
// writer's capture limit was exceeded or when the preview cap applies).
func extractResponsePreview(w *limitedResponseWriter) (preview string, truncated bool) {
	if w == nil {
		return "", false
	}
	body := bytes.TrimSpace(w.bodyBytes())
	switch {
	case len(body) == 0:
		return "", w.truncated()
	case len(body) > opsRetryResponsePreviewMax:
		return string(body[:opsRetryResponsePreviewMax]), true
	default:
		return string(body), w.truncated()
	}
}
// containsInt64 reports whether needle occurs in items.
// Delegates to the standard library's slices.Contains (Go 1.21+) instead of
// a hand-rolled loop.
func containsInt64(items []int64, needle int64) bool {
	return slices.Contains(items, needle)
}
// isFailoverError reports whether a retry error message indicates the gateway
// requested account failover (i.e. another account is worth trying).
func (s *OpsService) isFailoverError(message string) bool {
	msg := strings.ToLower(strings.TrimSpace(message))
	return msg != "" &&
		strings.Contains(msg, "upstream error:") &&
		strings.Contains(msg, "failover")
}

View File

@@ -0,0 +1,721 @@
package service
import (
"context"
"fmt"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/robfig/cron/v3"
)
const (
	// opsScheduledReportJobName identifies this background job (heartbeats etc.).
	opsScheduledReportJobName = "ops_scheduled_reports"
	// Redis key and TTL used for distributed leader election.
	opsScheduledReportLeaderLockKeyDefault = "ops:scheduled_reports:leader"
	opsScheduledReportLeaderLockTTLDefault = 5 * time.Minute
	// Key prefix for per-report last-run bookkeeping.
	opsScheduledReportLastRunKeyPrefix = "ops:scheduled_reports:last_run:"
	// How often the runner wakes up to check for due reports.
	opsScheduledReportTickInterval = 1 * time.Minute
)

// opsScheduledReportCronParser accepts classic 5-field cron expressions
// (minute hour day-of-month month day-of-week).
var opsScheduledReportCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)

// opsScheduledReportReleaseScript atomically releases the leader lock only if
// the caller (ARGV[1] = instance ID) still owns it — a compare-and-delete that
// avoids releasing a lock another instance has since acquired.
var opsScheduledReportReleaseScript = redis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
`)
// OpsScheduledReportService periodically evaluates configured ops reports and
// dispatches them via the email service. In clustered deployments a Redis
// leader lock ensures only one node runs a pass at a time.
type OpsScheduledReportService struct {
	opsService   *OpsService
	userService  *UserService
	emailService *EmailService
	redisClient  *redis.Client
	cfg          *config.Config
	// instanceID identifies this process for leader-lock ownership.
	instanceID string
	// loc is the timezone used when evaluating schedules.
	loc *time.Location
	// distributedLockOn is false in "simple" run mode (no Redis election).
	distributedLockOn bool
	// warnNoRedisOnce limits the missing-Redis warning to a single emission.
	warnNoRedisOnce sync.Once
	startOnce       sync.Once
	stopOnce        sync.Once
	// stopCtx/stop are created lazily in StartWithContext and drive shutdown.
	stopCtx context.Context
	stop    context.CancelFunc
	wg      sync.WaitGroup
}
// NewOpsScheduledReportService builds the scheduled-report runner.
//
// The distributed leader lock is enabled unless the server runs in simple
// mode; the schedule timezone falls back to time.Local when cfg.Timezone is
// empty or fails to load.
func NewOpsScheduledReportService(
	opsService *OpsService,
	userService *UserService,
	emailService *EmailService,
	redisClient *redis.Client,
	cfg *config.Config,
) *OpsScheduledReportService {
	lockOn := cfg == nil || strings.TrimSpace(cfg.RunMode) != config.RunModeSimple
	loc := time.Local
	if cfg != nil && strings.TrimSpace(cfg.Timezone) != "" {
		if parsed, err := time.LoadLocation(strings.TrimSpace(cfg.Timezone)); err == nil && parsed != nil {
			loc = parsed
		}
	}
	// Sync primitives and the stop context intentionally start at their zero
	// values; stopCtx/stop are initialized in StartWithContext. Spelling the
	// zero values out added noise without changing behavior.
	return &OpsScheduledReportService{
		opsService:        opsService,
		userService:       userService,
		emailService:      emailService,
		redisClient:       redisClient,
		cfg:               cfg,
		instanceID:        uuid.NewString(),
		loc:               loc,
		distributedLockOn: lockOn,
	}
}
// Start launches the background loop with a background context.
// See StartWithContext for the gating rules.
func (s *OpsScheduledReportService) Start() {
	s.StartWithContext(context.Background())
}
// StartWithContext launches the background loop at most once. It is a no-op
// when the receiver is nil, ops is disabled in config, or a required
// dependency (ops or email service) is missing. A nil ctx falls back to
// context.Background().
func (s *OpsScheduledReportService) StartWithContext(ctx context.Context) {
	if s == nil {
		return
	}
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return
	}
	if s.opsService == nil || s.emailService == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	s.startOnce.Do(func() {
		s.stopCtx, s.stop = context.WithCancel(ctx)
		s.wg.Add(1)
		go s.run()
	})
}
// Stop cancels the background loop (if it was ever started) and blocks until
// it exits. Safe to call multiple times and on a never-started service.
func (s *OpsScheduledReportService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		// stop is nil when StartWithContext never ran.
		if s.stop != nil {
			s.stop()
		}
	})
	s.wg.Wait()
}
// run is the scheduler goroutine: it evaluates schedules once immediately,
// then once per tick until the stop context is cancelled.
func (s *OpsScheduledReportService) run() {
	defer s.wg.Done()
	ticker := time.NewTicker(opsScheduledReportTickInterval)
	defer ticker.Stop()
	// Run immediately so a schedule matching the current minute is not
	// missed while waiting for the first tick.
	s.runOnce()
	for {
		select {
		case <-ticker.C:
			s.runOnce()
		case <-s.stopCtx.Done():
			return
		}
	}
}
// runOnce performs a single scheduler pass: acquire leadership, list due
// reports, run each, and record a job heartbeat. A report failure aborts the
// pass (remaining due reports retry on a later tick) and records an error
// heartbeat instead.
func (s *OpsScheduledReportService) runOnce() {
	if s == nil || s.opsService == nil || s.emailService == nil {
		return
	}
	startedAt := time.Now().UTC()
	runAt := startedAt
	// Bound the whole pass; deriving from stopCtx lets Stop() cancel in-flight work.
	ctx, cancel := context.WithTimeout(s.stopCtx, 60*time.Second)
	defer cancel()
	// Respect ops monitoring enabled switch.
	if !s.opsService.IsMonitoringEnabled(ctx) {
		return
	}
	release, ok := s.tryAcquireLeaderLock(ctx)
	if !ok {
		return
	}
	if release != nil {
		defer release()
	}
	now := time.Now()
	if s.loc != nil {
		// Schedules are evaluated in the configured timezone.
		now = now.In(s.loc)
	}
	reports := s.listScheduledReports(ctx, now)
	if len(reports) == 0 {
		return
	}
	reportsTotal := len(reports)
	reportsDue := 0
	sentAttempts := 0
	for _, report := range reports {
		if report == nil || !report.Enabled {
			continue
		}
		if report.NextRunAt.After(now) {
			continue
		}
		reportsDue++
		attempts, err := s.runReport(ctx, report, now)
		if err != nil {
			s.recordHeartbeatError(runAt, time.Since(startedAt), err)
			return
		}
		sentAttempts += attempts
	}
	result := truncateString(fmt.Sprintf("reports=%d due=%d send_attempts=%d", reportsTotal, reportsDue, sentAttempts), 2048)
	s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
// opsScheduledReport is one resolved report definition plus its scheduling state.
type opsScheduledReport struct {
	Name                            string        // display name, used in the email subject
	ReportType                      string        // daily_summary | weekly_summary | error_digest | account_health
	Schedule                        string        // raw 5-field cron spec
	Enabled                         bool
	TimeRange                       time.Duration // reporting window length, ending at "now"
	Recipients                      []string
	ErrorDigestMinCount             int     // suppress the digest below this many errors
	AccountHealthErrorRateThreshold float64 // reserved for a future per-account error-rate report
	LastRunAt                       *time.Time // nil when the report has never run
	NextRunAt                       time.Time  // next cron fire time; due when <= now
}
// listScheduledReports builds the set of enabled report definitions from the
// email notification config, resolving each cron schedule to its next fire
// time relative to the last persisted run. Returns nil when reporting is
// disabled or the config cannot be loaded; invalid cron specs are logged and
// skipped.
func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, now time.Time) []*opsScheduledReport {
	if s == nil || s.opsService == nil {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
	if err != nil || emailCfg == nil {
		return nil
	}
	if !emailCfg.Report.Enabled {
		return nil
	}
	recipients := normalizeEmails(emailCfg.Report.Recipients)
	type reportDef struct {
		enabled   bool
		name      string
		kind      string
		timeRange time.Duration
		schedule  string
	}
	defs := []reportDef{
		{enabled: emailCfg.Report.DailySummaryEnabled, name: "日报", kind: "daily_summary", timeRange: 24 * time.Hour, schedule: emailCfg.Report.DailySummarySchedule},
		{enabled: emailCfg.Report.WeeklySummaryEnabled, name: "周报", kind: "weekly_summary", timeRange: 7 * 24 * time.Hour, schedule: emailCfg.Report.WeeklySummarySchedule},
		{enabled: emailCfg.Report.ErrorDigestEnabled, name: "错误摘要", kind: "error_digest", timeRange: 24 * time.Hour, schedule: emailCfg.Report.ErrorDigestSchedule},
		{enabled: emailCfg.Report.AccountHealthEnabled, name: "账号健康", kind: "account_health", timeRange: 24 * time.Hour, schedule: emailCfg.Report.AccountHealthSchedule},
	}
	out := make([]*opsScheduledReport, 0, len(defs))
	for _, d := range defs {
		if !d.enabled {
			continue
		}
		spec := strings.TrimSpace(d.schedule)
		if spec == "" {
			continue
		}
		sched, err := opsScheduledReportCronParser.Parse(spec)
		if err != nil {
			log.Printf("[OpsScheduledReport] invalid cron spec=%q for report=%s: %v", spec, d.kind, err)
			continue
		}
		lastRun := s.getLastRunAt(ctx, d.kind)
		base := lastRun
		if base.IsZero() {
			// Allow a schedule matching the current minute to trigger right after startup.
			base = now.Add(-1 * time.Minute)
		}
		next := sched.Next(base)
		if next.IsZero() {
			continue
		}
		var lastRunPtr *time.Time
		if !lastRun.IsZero() {
			// Copy before taking the address so the pointer doesn't alias the loop variable.
			lastCopy := lastRun
			lastRunPtr = &lastCopy
		}
		out = append(out, &opsScheduledReport{
			Name:                            d.name,
			ReportType:                      d.kind,
			Schedule:                        spec,
			Enabled:                         true,
			TimeRange:                       d.timeRange,
			Recipients:                      recipients,
			ErrorDigestMinCount:             emailCfg.Report.ErrorDigestMinCount,
			AccountHealthErrorRateThreshold: emailCfg.Report.AccountHealthErrorRateThreshold,
			LastRunAt:                       lastRunPtr,
			NextRunAt:                       next,
		})
	}
	return out
}
// runReport generates and emails a single due report and returns the number
// of send attempts. The last-run marker is persisted before generation so a
// persistently broken SMTP/report setup is not retried every minute; an empty
// generated body (e.g. digest below its minimum count) skips sending.
// Per-recipient send failures are ignored (best-effort).
func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) (int, error) {
	if s == nil || s.opsService == nil || s.emailService == nil || report == nil {
		return 0, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Mark as "run" up-front so a broken SMTP config doesn't spam retries every minute.
	s.setLastRunAt(ctx, report.ReportType, now)
	content, err := s.generateReportHTML(ctx, report, now)
	if err != nil {
		return 0, err
	}
	if strings.TrimSpace(content) == "" {
		// Skip sending when the report decides not to emit content (e.g., digest below min count).
		return 0, nil
	}
	recipients := report.Recipients
	if len(recipients) == 0 && s.userService != nil {
		// No configured recipients: fall back to the first admin's email.
		admin, err := s.userService.GetFirstAdmin(ctx)
		if err == nil && admin != nil && strings.TrimSpace(admin.Email) != "" {
			recipients = []string{strings.TrimSpace(admin.Email)}
		}
	}
	if len(recipients) == 0 {
		return 0, nil
	}
	subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
	attempts := 0
	for _, to := range recipients {
		addr := strings.TrimSpace(to)
		if addr == "" {
			continue
		}
		attempts++
		if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
			// Ignore per-recipient failures; continue best-effort.
			continue
		}
	}
	return attempts, nil
}
// generateReportHTML renders the HTML body for one report type over the UTC
// window [now-TimeRange, now]. Returning ("", nil) means "nothing to send";
// an unknown report type is an error.
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
	if s == nil || s.opsService == nil || report == nil {
		return "", fmt.Errorf("service not initialized")
	}
	if report.TimeRange <= 0 {
		return "", fmt.Errorf("invalid time range")
	}
	end := now.UTC()
	start := end.Add(-report.TimeRange)
	switch strings.TrimSpace(report.ReportType) {
	case "daily_summary", "weekly_summary":
		overview, err := s.opsService.GetDashboardOverview(ctx, &OpsDashboardFilter{
			StartTime: start,
			EndTime:   end,
			Platform:  "",
			GroupID:   nil,
			QueryMode: OpsQueryModeAuto,
		})
		if err != nil {
			// If pre-aggregation isn't ready but the report is requested, fall back to raw.
			if strings.TrimSpace(report.ReportType) == "daily_summary" || strings.TrimSpace(report.ReportType) == "weekly_summary" {
				overview, err = s.opsService.GetDashboardOverview(ctx, &OpsDashboardFilter{
					StartTime: start,
					EndTime:   end,
					Platform:  "",
					GroupID:   nil,
					QueryMode: OpsQueryModeRaw,
				})
			}
			if err != nil {
				return "", err
			}
		}
		return buildOpsSummaryEmailHTML(report.Name, start, end, overview), nil
	case "error_digest":
		// Lightweight digest: list recent errors (status>=400) and breakdown by type.
		startTime := start
		endTime := end
		filter := &OpsErrorLogFilter{
			StartTime: &startTime,
			EndTime:   &endTime,
			Page:      1,
			PageSize:  100,
		}
		out, err := s.opsService.GetErrorLogs(ctx, filter)
		if err != nil {
			return "", err
		}
		if report.ErrorDigestMinCount > 0 && out != nil && out.Total < report.ErrorDigestMinCount {
			// Below the configured threshold: suppress the email entirely.
			return "", nil
		}
		return buildOpsErrorDigestEmailHTML(report.Name, start, end, out), nil
	case "account_health":
		// Best-effort: use account availability (not error rate yet).
		avail, err := s.opsService.GetAccountAvailability(ctx, "", nil)
		if err != nil {
			return "", err
		}
		_ = report.AccountHealthErrorRateThreshold // reserved for future per-account error rate report
		return buildOpsAccountHealthEmailHTML(report.Name, start, end, avail), nil
	default:
		return "", fmt.Errorf("unknown report type: %s", report.ReportType)
	}
}
// buildOpsSummaryEmailHTML renders the daily/weekly summary email body from a
// dashboard overview. Percentile fields are optional and render as "-" when
// absent. All dynamic text is HTML-escaped.
func buildOpsSummaryEmailHTML(title string, start, end time.Time, overview *OpsDashboardOverview) string {
	if overview == nil {
		return fmt.Sprintf("<h2>%s</h2><p>No data.</p>", htmlEscape(title))
	}
	latP50 := "-"
	latP99 := "-"
	if overview.Duration.P50 != nil {
		latP50 = fmt.Sprintf("%dms", *overview.Duration.P50)
	}
	if overview.Duration.P99 != nil {
		latP99 = fmt.Sprintf("%dms", *overview.Duration.P99)
	}
	ttftP50 := "-"
	ttftP99 := "-"
	if overview.TTFT.P50 != nil {
		ttftP50 = fmt.Sprintf("%dms", *overview.TTFT.P50)
	}
	if overview.TTFT.P99 != nil {
		ttftP99 = fmt.Sprintf("%dms", *overview.TTFT.P99)
	}
	// NOTE: the argument list below is positional and must stay in sync with
	// the placeholders in the template.
	return fmt.Sprintf(`
<h2>%s</h2>
<p><b>Period</b>: %s ~ %s (UTC)</p>
<ul>
<li><b>Total Requests</b>: %d</li>
<li><b>Success</b>: %d</li>
<li><b>Errors (SLA)</b>: %d</li>
<li><b>Business Limited</b>: %d</li>
<li><b>SLA</b>: %.2f%%</li>
<li><b>Error Rate</b>: %.2f%%</li>
<li><b>Upstream Error Rate (excl 429/529)</b>: %.2f%%</li>
<li><b>Upstream Errors</b>: excl429/529=%d, 429=%d, 529=%d</li>
<li><b>Latency</b>: p50=%s, p99=%s</li>
<li><b>TTFT</b>: p50=%s, p99=%s</li>
<li><b>Tokens</b>: %d</li>
<li><b>QPS</b>: current=%.1f, peak=%.1f, avg=%.1f</li>
<li><b>TPS</b>: current=%.1f, peak=%.1f, avg=%.1f</li>
</ul>
`,
		htmlEscape(strings.TrimSpace(title)),
		htmlEscape(start.UTC().Format(time.RFC3339)),
		htmlEscape(end.UTC().Format(time.RFC3339)),
		overview.RequestCountTotal,
		overview.SuccessCount,
		overview.ErrorCountSLA,
		overview.BusinessLimitedCount,
		overview.SLA*100,
		overview.ErrorRate*100,
		overview.UpstreamErrorRate*100,
		overview.UpstreamErrorCountExcl429529,
		overview.Upstream429Count,
		overview.Upstream529Count,
		htmlEscape(latP50),
		htmlEscape(latP99),
		htmlEscape(ttftP50),
		htmlEscape(ttftP99),
		overview.TokenConsumed,
		overview.QPS.Current,
		overview.QPS.Peak,
		overview.QPS.Avg,
		overview.TPS.Current,
		overview.TPS.Peak,
		overview.TPS.Avg,
	)
}
// buildOpsErrorDigestEmailHTML renders the error-digest email: a period
// header, the total error count, and a table of up to 10 recent errors.
// All dynamic text is HTML-escaped; messages are truncated to 180 chars.
func buildOpsErrorDigestEmailHTML(title string, start, end time.Time, list *OpsErrorLogList) string {
	total := 0
	recent := []*OpsErrorLog{}
	if list != nil {
		total = list.Total
		recent = list.Errors
	}
	const maxRows = 10
	if len(recent) > maxRows {
		recent = recent[:maxRows]
	}
	var rowBuilder strings.Builder
	for _, entry := range recent {
		if entry == nil {
			continue
		}
		rowBuilder.WriteString(fmt.Sprintf(
			"<tr><td>%s</td><td>%s</td><td>%d</td><td>%s</td></tr>",
			htmlEscape(entry.CreatedAt.UTC().Format(time.RFC3339)),
			htmlEscape(entry.Platform),
			entry.StatusCode,
			htmlEscape(truncateString(entry.Message, 180)),
		))
	}
	rows := rowBuilder.String()
	if rows == "" {
		rows = "<tr><td colspan=\"4\">No recent errors.</td></tr>"
	}
	return fmt.Sprintf(`
<h2>%s</h2>
<p><b>Period</b>: %s ~ %s (UTC)</p>
<p><b>Total Errors</b>: %d</p>
<h3>Recent</h3>
<table border="1" cellpadding="6" cellspacing="0" style="border-collapse:collapse;">
<thead><tr><th>Time</th><th>Platform</th><th>Status</th><th>Message</th></tr></thead>
<tbody>%s</tbody>
</table>
`,
		htmlEscape(strings.TrimSpace(title)),
		htmlEscape(start.UTC().Format(time.RFC3339)),
		htmlEscape(end.UTC().Format(time.RFC3339)),
		total,
		rows,
	)
}
// buildOpsAccountHealthEmailHTML renders the account-health email with
// aggregate counters (total / available / rate-limited / error) computed
// from the availability snapshot. Nil snapshots and nil entries count as zero.
func buildOpsAccountHealthEmailHTML(title string, start, end time.Time, avail *OpsAccountAvailability) string {
	var total, available, rateLimited, hasError int
	if avail != nil {
		for _, acct := range avail.Accounts {
			if acct == nil {
				continue
			}
			total++
			if acct.IsAvailable {
				available++
			}
			if acct.IsRateLimited {
				rateLimited++
			}
			if acct.HasError {
				hasError++
			}
		}
	}
	return fmt.Sprintf(`
<h2>%s</h2>
<p><b>Period</b>: %s ~ %s (UTC)</p>
<ul>
<li><b>Total Accounts</b>: %d</li>
<li><b>Available</b>: %d</li>
<li><b>Rate Limited</b>: %d</li>
<li><b>Error</b>: %d</li>
</ul>
<p>Note: This report currently reflects account availability status only.</p>
`,
		htmlEscape(strings.TrimSpace(title)),
		htmlEscape(start.UTC().Format(time.RFC3339)),
		htmlEscape(end.UTC().Format(time.RFC3339)),
		total,
		available,
		rateLimited,
		hasError,
	)
}
// tryAcquireLeaderLock performs Redis-based leader election so only one
// instance sends reports. It returns (release, true) on success; release may
// be nil when the lock is disabled (simple run mode) or Redis is not
// configured. A false result means another instance leads, or Redis errored
// (fail-closed to avoid duplicate sends).
func (s *OpsScheduledReportService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
	if s == nil || !s.distributedLockOn {
		return nil, true
	}
	if s.redisClient == nil {
		s.warnNoRedisOnce.Do(func() {
			log.Printf("[OpsScheduledReport] redis not configured; running without distributed lock")
		})
		return nil, true
	}
	if ctx == nil {
		ctx = context.Background()
	}
	key := opsScheduledReportLeaderLockKeyDefault
	ttl := opsScheduledReportLeaderLockTTLDefault
	// Defensive defaults in case the package-level constants are ever emptied.
	if strings.TrimSpace(key) == "" {
		key = "ops:scheduled_reports:leader"
	}
	if ttl <= 0 {
		ttl = 5 * time.Minute
	}
	ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
	if err != nil {
		// Prefer fail-closed to avoid duplicate report sends when Redis is flaky.
		log.Printf("[OpsScheduledReport] leader lock SetNX failed; skipping this cycle: %v", err)
		return nil, false
	}
	if !ok {
		return nil, false
	}
	// Compare-and-delete release: never delete a lock another instance re-acquired.
	return func() {
		_, _ = opsScheduledReportReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
	}, true
}
// getLastRunAt loads the persisted last-run timestamp (unix seconds) for a
// report type from Redis. It returns the zero time when Redis is not
// configured, the key is missing, or the stored value is not a positive
// integer.
func (s *OpsScheduledReportService) getLastRunAt(ctx context.Context, reportType string) time.Time {
	if s == nil || s.redisClient == nil {
		return time.Time{}
	}
	kind := strings.TrimSpace(reportType)
	if kind == "" {
		return time.Time{}
	}
	key := opsScheduledReportLastRunKeyPrefix + kind
	raw, err := s.redisClient.Get(ctx, key).Result()
	if err != nil || strings.TrimSpace(raw) == "" {
		return time.Time{}
	}
	sec, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
	if err != nil || sec <= 0 {
		return time.Time{}
	}
	last := time.Unix(sec, 0)
	// Cron schedules are interpreted in the configured timezone (s.loc). Ensure the base time
	// passed into cron.Next() uses the same location; otherwise the job will drift by timezone
	// offset (e.g. Asia/Shanghai default would run 8h later after the first execution).
	if s.loc != nil {
		return last.In(s.loc)
	}
	return last.UTC()
}
// setLastRunAt persists the last-run timestamp for a report type as unix
// seconds with a 14-day TTL (best-effort; write errors are deliberately
// ignored). A zero t is replaced with the current time.
func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType string, t time.Time) {
	if s == nil || s.redisClient == nil {
		return
	}
	kind := strings.TrimSpace(reportType)
	if kind == "" {
		return
	}
	if t.IsZero() {
		t = time.Now().UTC()
	}
	key := opsScheduledReportLastRunKeyPrefix + kind
	_ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err()
}
// recordHeartbeatSuccess records a successful scheduler pass in the ops job
// heartbeat table (best-effort; uses its own short-timeout context so a slow
// DB cannot stall the scheduler). An empty result is stored as "ok".
func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
	if s == nil || s.opsService == nil || s.opsService.opsRepo == nil {
		return
	}
	now := time.Now().UTC()
	durMs := duration.Milliseconds()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	msg := strings.TrimSpace(result)
	if msg == "" {
		msg = "ok"
	}
	msg = truncateString(msg, 2048)
	_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsScheduledReportJobName,
		LastRunAt:      &runAt,
		LastSuccessAt:  &now,
		LastDurationMs: &durMs,
		LastResult:     &msg,
	})
}
// recordHeartbeatError records a failed scheduler pass in the ops job
// heartbeat table (best-effort; short-timeout context, errors ignored).
// A nil err is a no-op.
func (s *OpsScheduledReportService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
	if s == nil || s.opsService == nil || s.opsService.opsRepo == nil || err == nil {
		return
	}
	now := time.Now().UTC()
	durMs := duration.Milliseconds()
	msg := truncateString(err.Error(), 2048)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
		JobName:        opsScheduledReportJobName,
		LastRunAt:      &runAt,
		LastErrorAt:    &now,
		LastError:      &msg,
		LastDurationMs: &durMs,
	})
}
// normalizeEmails lowercases, trims, and de-duplicates an address list while
// preserving first-seen order. Blank entries are dropped; a nil/empty input
// yields nil.
func normalizeEmails(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	out := make([]string, 0, len(in))
	seen := make(map[string]struct{}, len(in))
	for _, item := range in {
		email := strings.ToLower(strings.TrimSpace(item))
		if email == "" {
			continue
		}
		if _, dup := seen[email]; dup {
			continue
		}
		seen[email] = struct{}{}
		out = append(out, email)
	}
	return out
}

View File

@@ -0,0 +1,613 @@
package service
import (
"context"
"database/sql"
"encoding/json"
"errors"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// ErrOpsDisabled is returned by query APIs when ops monitoring is switched off.
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")

const (
	// opsMaxStoredRequestBodyBytes caps the sanitized request body persisted with an error log.
	opsMaxStoredRequestBodyBytes = 10 * 1024
	// opsMaxStoredErrorBodyBytes caps the sanitized error/upstream bodies persisted with an error log.
	opsMaxStoredErrorBodyBytes = 20 * 1024
)
// OpsService provides ingestion and query APIs for the Ops monitoring module.
type OpsService struct {
	opsRepo     OpsRepository
	settingRepo SettingRepository // runtime on/off switch + notification config storage
	cfg         *config.Config
	accountRepo AccountRepository
	// getAccountAvailability is a unit-test hook for overriding account availability lookup.
	getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error)
	concurrencyService     *ConcurrencyService
	// Per-platform gateway services (usage not visible in this chunk;
	// presumably used by retry/replay features — confirm against callers).
	gatewayService            *GatewayService
	openAIGatewayService      *OpenAIGatewayService
	geminiCompatService       *GeminiMessagesCompatService
	antigravityGatewayService *AntigravityGatewayService
}
// NewOpsService wires the ops monitoring service with its repositories,
// configuration, and the per-platform gateway services. Any dependency may
// be nil; methods guard against missing collaborators.
func NewOpsService(
	opsRepo OpsRepository,
	settingRepo SettingRepository,
	cfg *config.Config,
	accountRepo AccountRepository,
	concurrencyService *ConcurrencyService,
	gatewayService *GatewayService,
	openAIGatewayService *OpenAIGatewayService,
	geminiCompatService *GeminiMessagesCompatService,
	antigravityGatewayService *AntigravityGatewayService,
) *OpsService {
	return &OpsService{
		opsRepo:                   opsRepo,
		settingRepo:               settingRepo,
		cfg:                       cfg,
		accountRepo:               accountRepo,
		concurrencyService:        concurrencyService,
		gatewayService:            gatewayService,
		openAIGatewayService:      openAIGatewayService,
		geminiCompatService:       geminiCompatService,
		antigravityGatewayService: antigravityGatewayService,
	}
}
// RequireMonitoringEnabled returns ErrOpsDisabled when ops monitoring is
// currently switched off, and nil otherwise.
func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error {
	if !s.IsMonitoringEnabled(ctx) {
		return ErrOpsDisabled
	}
	return nil
}
// IsMonitoringEnabled reports whether ops monitoring is currently enabled.
// Precedence: the hard config switch (cfg.Ops.Enabled) wins; otherwise the
// runtime setting decides, where only an explicit "false"/"0"/"off"/
// "disabled" value disables monitoring. Missing settings and read errors are
// treated as enabled (fail-open: ops must never block gateway traffic).
func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
	// Hard switch: disable ops entirely.
	if s.cfg != nil && !s.cfg.Ops.Enabled {
		return false
	}
	if s.settingRepo == nil {
		return true
	}
	value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
	if err != nil {
		// Fail-open for both a missing key (default enabled) and transient
		// read errors. The previous errors.Is(err, ErrSettingNotFound) branch
		// was redundant — both paths returned true.
		return true
	}
	switch strings.ToLower(strings.TrimSpace(value)) {
	case "false", "0", "off", "disabled":
		return false
	default:
		return true
	}
}
// RecordError persists one error log entry, best-effort: it never blocks or
// fails gateway traffic (the returned error is informational; callers may
// ignore it). Before insertion it fills required defaults, redacts and trims
// the raw request body, sanitizes the error body and upstream error context,
// and serializes the (capped) upstream error event list to JSON.
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
	if entry == nil {
		return nil
	}
	if !s.IsMonitoringEnabled(ctx) {
		return nil
	}
	if s.opsRepo == nil {
		return nil
	}
	// Ensure timestamps are always populated.
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = time.Now()
	}
	// Ensure required fields exist (DB has NOT NULL constraints).
	entry.ErrorPhase = strings.TrimSpace(entry.ErrorPhase)
	entry.ErrorType = strings.TrimSpace(entry.ErrorType)
	if entry.ErrorPhase == "" {
		entry.ErrorPhase = "internal"
	}
	if entry.ErrorType == "" {
		entry.ErrorType = "api_error"
	}
	// Sanitize + trim request body (errors only).
	if len(rawRequestBody) > 0 {
		sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes)
		if sanitized != "" {
			entry.RequestBodyJSON = &sanitized
		}
		entry.RequestBodyTruncated = truncated
		entry.RequestBodyBytes = &bytesLen
	}
	// Sanitize + truncate error_body to avoid storing sensitive data.
	if strings.TrimSpace(entry.ErrorBody) != "" {
		sanitized, _ := sanitizeErrorBodyForStorage(entry.ErrorBody, opsMaxStoredErrorBodyBytes)
		entry.ErrorBody = sanitized
	}
	// Sanitize upstream error context if provided by gateway services.
	if entry.UpstreamStatusCode != nil && *entry.UpstreamStatusCode <= 0 {
		entry.UpstreamStatusCode = nil
	}
	if entry.UpstreamErrorMessage != nil {
		msg := strings.TrimSpace(*entry.UpstreamErrorMessage)
		msg = sanitizeUpstreamErrorMessage(msg)
		msg = truncateString(msg, 2048)
		if strings.TrimSpace(msg) == "" {
			entry.UpstreamErrorMessage = nil
		} else {
			entry.UpstreamErrorMessage = &msg
		}
	}
	if entry.UpstreamErrorDetail != nil {
		detail := strings.TrimSpace(*entry.UpstreamErrorDetail)
		if detail == "" {
			entry.UpstreamErrorDetail = nil
		} else {
			sanitized, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
			if strings.TrimSpace(sanitized) == "" {
				entry.UpstreamErrorDetail = nil
			} else {
				entry.UpstreamErrorDetail = &sanitized
			}
		}
	}
	// Sanitize + serialize upstream error events list.
	if len(entry.UpstreamErrors) > 0 {
		// Keep only the most recent events to bound stored JSON size.
		const maxEvents = 32
		events := entry.UpstreamErrors
		if len(events) > maxEvents {
			events = events[len(events)-maxEvents:]
		}
		sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
		for _, ev := range events {
			if ev == nil {
				continue
			}
			// Work on a copy so the caller's event is never mutated.
			out := *ev
			out.Platform = strings.TrimSpace(out.Platform)
			out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
			out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
			if out.AccountID < 0 {
				out.AccountID = 0
			}
			if out.UpstreamStatusCode < 0 {
				out.UpstreamStatusCode = 0
			}
			if out.AtUnixMs < 0 {
				out.AtUnixMs = 0
			}
			msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
			msg = truncateString(msg, 2048)
			out.Message = msg
			detail := strings.TrimSpace(out.Detail)
			if detail != "" {
				// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
				sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
				out.Detail = sanitizedDetail
			} else {
				out.Detail = ""
			}
			out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
			if out.UpstreamRequestBody != "" {
				// Reuse the same sanitization/trimming strategy as request body storage.
				// Keep it small so it is safe to persist in ops_error_logs JSON.
				sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
				if sanitized != "" {
					out.UpstreamRequestBody = sanitized
					if truncated {
						out.Kind = strings.TrimSpace(out.Kind)
						if out.Kind == "" {
							out.Kind = "upstream"
						}
						out.Kind = out.Kind + ":request_body_truncated"
					}
				} else {
					out.UpstreamRequestBody = ""
				}
			}
			// Drop fully-empty events (can happen if only status code was known).
			if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
				continue
			}
			evCopy := out
			sanitized = append(sanitized, &evCopy)
		}
		entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
		entry.UpstreamErrors = nil
	}
	if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
		// Never bubble up to gateway; best-effort logging.
		log.Printf("[Ops] RecordError failed: %v", err)
		return err
	}
	return nil
}
// GetErrorLogs returns a filtered, paginated page of ops error logs.
// When the repository is not wired it returns an empty first page rather
// than failing.
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
	}
	result, err := s.opsRepo.ListErrorLogs(ctx, filter)
	if err != nil {
		log.Printf("[Ops] GetErrorLogs failed: %v", err)
		return nil, err
	}
	return result, nil
}
// GetErrorLogByID loads one error log with full detail, mapping a missing row
// (sql.ErrNoRows) to an OPS_ERROR_NOT_FOUND error and everything else to an
// internal error with the cause attached.
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
	}
	detail, err := s.opsRepo.GetErrorLogByID(ctx, id)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
		}
		return nil, infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
	}
	return detail, nil
}
// ListRetryAttemptsByErrorID lists retry attempts recorded for an error log.
// An id with no rows yields an empty (non-nil) slice rather than an error;
// a non-positive id is rejected as a bad request.
func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if errorID <= 0 {
		return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
	}
	items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return []*OpsRetryAttempt{}, nil
		}
		return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
	}
	return items, nil
}
// UpdateErrorResolution marks an error log as resolved or unresolved,
// verifying the row exists first. resolvedByUserID and resolvedRetryID
// optionally record who/what resolved it.
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return err
	}
	if s.opsRepo == nil {
		return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if errorID <= 0 {
		return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
	}
	// Best-effort ensure the error exists
	if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
		}
		return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
	}
	return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
}
// sanitizeAndTrimRequestBody redacts credential-like fields from a JSON
// request body and shrinks it to at most maxBytes, degrading in stages:
// (1) full redacted body; (2) conversation history trimmed from the front;
// (3) essentials only (model/params + last message); (4) original keys kept
// but large values emptied; (5) a minimal {"request_body_truncated": true}
// marker. Non-JSON input is not stored at all (empty jsonString). bytesLen
// is always the size of the raw input.
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
	bytesLen = len(raw)
	if len(raw) == 0 {
		return "", false, 0
	}
	var decoded any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		// If it's not valid JSON, don't store (retry would not be reliable anyway).
		return "", false, bytesLen
	}
	decoded = redactSensitiveJSON(decoded)
	encoded, err := json.Marshal(decoded)
	if err != nil {
		return "", false, bytesLen
	}
	if len(encoded) <= maxBytes {
		return string(encoded), false, bytesLen
	}
	// Trim conversation history to keep the most recent context.
	if root, ok := decoded.(map[string]any); ok {
		if trimmed, ok := trimConversationArrays(root, maxBytes); ok {
			encoded2, err2 := json.Marshal(trimmed)
			if err2 == nil && len(encoded2) <= maxBytes {
				return string(encoded2), true, bytesLen
			}
			// Fallthrough: keep shrinking.
			decoded = trimmed
		}
		essential := shrinkToEssentials(root)
		encoded3, err3 := json.Marshal(essential)
		if err3 == nil && len(encoded3) <= maxBytes {
			return string(encoded3), true, bytesLen
		}
	}
	// Last resort: keep JSON shape but drop big fields.
	// This avoids downstream code that expects certain top-level keys from crashing.
	if root, ok := decoded.(map[string]any); ok {
		placeholder := shallowCopyMap(root)
		placeholder["request_body_truncated"] = true
		// Replace potentially huge arrays/strings, but keep the keys present.
		for _, k := range []string{"messages", "contents", "input", "prompt"} {
			if _, exists := placeholder[k]; exists {
				placeholder[k] = []any{}
			}
		}
		for _, k := range []string{"text"} {
			if _, exists := placeholder[k]; exists {
				placeholder[k] = ""
			}
		}
		encoded4, err4 := json.Marshal(placeholder)
		if err4 == nil {
			if len(encoded4) <= maxBytes {
				return string(encoded4), true, bytesLen
			}
		}
	}
	// Final fallback: minimal valid JSON.
	encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
	if err4 != nil {
		return "", true, bytesLen
	}
	return string(encoded4), true, bytesLen
}
// redactSensitiveJSON walks a decoded JSON value and replaces every value
// whose key looks credential-like (per isSensitiveKey) with "[REDACTED]".
// Maps and arrays are rebuilt (the input is never mutated); scalars pass
// through unchanged.
func redactSensitiveJSON(v any) any {
	switch val := v.(type) {
	case map[string]any:
		clean := make(map[string]any, len(val))
		for key, inner := range val {
			if isSensitiveKey(key) {
				clean[key] = "[REDACTED]"
			} else {
				clean[key] = redactSensitiveJSON(inner)
			}
		}
		return clean
	case []any:
		clean := make([]any, len(val))
		for i, inner := range val {
			clean[i] = redactSensitiveJSON(inner)
		}
		return clean
	default:
		return v
	}
}
// isSensitiveKey reports whether a JSON key likely names a credential or
// other secret. Matching is case-insensitive and intentionally aggressive,
// checked in three passes: exact credential names, sensitive suffixes, then
// substring heuristics that err on the side of privacy.
func isSensitiveKey(key string) bool {
	normalized := strings.ToLower(strings.TrimSpace(key))
	if normalized == "" {
		return false
	}
	// Pass 1: exact matches (common credential fields).
	exactNames := []string{
		"authorization", "proxy-authorization", "x-api-key", "api_key",
		"apikey", "access_token", "refresh_token", "id_token",
		"session_token", "token", "password", "passwd", "passphrase",
		"secret", "client_secret", "private_key", "jwt", "signature",
		"accesskeyid", "secretaccesskey",
	}
	for _, name := range exactNames {
		if normalized == name {
			return true
		}
	}
	// Pass 2: sensitive suffixes.
	sensitiveSuffixes := []string{
		"_secret", "_token", "_id_token", "_session_token", "_password",
		"_passwd", "_passphrase", "_key", "secret_key", "private_key",
	}
	for _, suffix := range sensitiveSuffixes {
		if strings.HasSuffix(normalized, suffix) {
			return true
		}
	}
	// Pass 3: substring heuristics (conservative, errs on the side of privacy).
	sensitiveSubstrings := []string{
		"secret", "token", "password", "passwd", "passphrase",
		"privatekey", "private_key", "apikey", "api_key", "accesskeyid",
		"secretaccesskey", "bearer", "cookie", "credential", "session",
		"jwt", "signature",
	}
	for _, fragment := range sensitiveSubstrings {
		if strings.Contains(normalized, fragment) {
			return true
		}
	}
	return false
}
// trimConversationArrays attempts to shrink the request's conversation
// history so the re-encoded JSON fits maxBytes. It tries "messages"
// (Anthropic/OpenAI) first, then "contents" (Gemini); the bool reports
// whether either field was trimmed.
func trimConversationArrays(root map[string]any, maxBytes int) (map[string]any, bool) {
	for _, field := range []string{"messages", "contents"} {
		if trimmed, ok := trimArrayField(root, field, maxBytes); ok {
			return trimmed, true
		}
	}
	return root, false
}
// trimArrayField drops the oldest entries of root[field] (a JSON array) until
// the re-encoded document fits maxBytes, binary-searching the cut point so we
// marshal O(log n) times instead of O(n). At least one element is always
// kept; if even one element does not fit, the single-element slice is still
// returned and the caller falls back to shrinkToEssentials(). The bool is
// false only when the field is absent or not a non-empty array.
func trimArrayField(root map[string]any, field string, maxBytes int) (map[string]any, bool) {
	raw, ok := root[field]
	if !ok {
		return nil, false
	}
	arr, ok := raw.([]any)
	if !ok || len(arr) == 0 {
		return nil, false
	}
	// Keep at least the last message/content. Use binary search so we don't marshal O(n) times.
	// We are dropping from the *front* of the array (oldest context first).
	lo := 0
	hi := len(arr) - 1 // inclusive; hi ensures at least one item remains
	var best map[string]any
	found := false
	for lo <= hi {
		mid := (lo + hi) / 2
		candidateArr := arr[mid:]
		if len(candidateArr) == 0 {
			lo = mid + 1
			continue
		}
		next := shallowCopyMap(root)
		next[field] = candidateArr
		encoded, err := json.Marshal(next)
		if err != nil {
			// If marshal fails, try dropping more.
			lo = mid + 1
			continue
		}
		if len(encoded) <= maxBytes {
			best = next
			found = true
			// Try to keep more context by dropping fewer items.
			hi = mid - 1
			continue
		}
		// Need to drop more.
		lo = mid + 1
	}
	if found {
		return best, true
	}
	// Nothing fit (even with only one element); return the smallest slice and let the
	// caller fall back to shrinkToEssentials().
	next := shallowCopyMap(root)
	next[field] = arr[len(arr)-1:]
	return next, true
}
// shrinkToEssentials builds a minimal snapshot of a request: a handful of
// scalar sampling parameters plus only the LAST entry of each conversation
// array ("messages"/"contents"). Used as a fallback when trimming history
// alone cannot meet the size budget.
func shrinkToEssentials(root map[string]any) map[string]any {
	essentials := make(map[string]any)
	for _, key := range []string{"model", "stream", "max_tokens", "temperature", "top_p", "top_k"} {
		if val, ok := root[key]; ok {
			essentials[key] = val
		}
	}
	// Keep only the newest conversation entry per supported field.
	for _, field := range []string{"messages", "contents"} {
		raw, ok := root[field]
		if !ok {
			continue
		}
		if arr, ok := raw.([]any); ok && len(arr) > 0 {
			essentials[field] = []any{arr[len(arr)-1]}
		}
	}
	return essentials
}
// shallowCopyMap returns a new map with the same key/value pairs as m.
// Only the top level is copied; nested values remain shared.
func shallowCopyMap(m map[string]any) map[string]any {
	dst := make(map[string]any, len(m))
	for key, val := range m {
		dst[key] = val
	}
	return dst
}
// sanitizeErrorBodyForStorage prepares an error body for persistence. Bodies
// that parse as JSON are redacted and trimmed via sanitizeAndTrimRequestBody;
// anything else is truncated as-is. The bool reports whether content was cut.
func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, truncated bool) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return "", false
	}
	// Prefer JSON-safe sanitization when possible.
	if out, trunc, _ := sanitizeAndTrimRequestBody([]byte(raw), maxBytes); out != "" {
		return out, trunc
	}
	// Non-JSON: best-effort truncate.
	if maxBytes > 0 && len(raw) > maxBytes {
		return truncateString(raw, maxBytes), true
	}
	return raw, false
}

View File

@@ -0,0 +1,562 @@
package service
import (
"context"
"encoding/json"
"errors"
"strings"
"time"
)
const (
	// opsAlertEvaluatorLeaderLockKeyDefault is the Redis key used for the
	// alert-evaluator leader election lock.
	opsAlertEvaluatorLeaderLockKeyDefault = "ops:alert:evaluator:leader"
	// opsAlertEvaluatorLeaderLockTTLDefault bounds how long a crashed leader
	// can hold the lock before another instance may take over.
	opsAlertEvaluatorLeaderLockTTLDefault = 30 * time.Second
)
// =========================
// Email notification config
// =========================
// GetEmailNotificationConfig loads the ops email notification config from
// settings. A missing key yields the defaults (and persists them best-effort
// for next time); corrupted JSON also falls back to defaults rather than
// breaking the ops UI. Loaded configs are normalized before being returned.
func (s *OpsService) GetEmailNotificationConfig(ctx context.Context) (*OpsEmailNotificationConfig, error) {
	defaultCfg := defaultOpsEmailNotificationConfig()
	if s == nil || s.settingRepo == nil {
		return defaultCfg, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsEmailNotificationConfig)
	if err != nil {
		if errors.Is(err, ErrSettingNotFound) {
			// Initialize defaults on first read (best-effort).
			if b, mErr := json.Marshal(defaultCfg); mErr == nil {
				_ = s.settingRepo.Set(ctx, SettingKeyOpsEmailNotificationConfig, string(b))
			}
			return defaultCfg, nil
		}
		return nil, err
	}
	cfg := &OpsEmailNotificationConfig{}
	if err := json.Unmarshal([]byte(raw), cfg); err != nil {
		// Corrupted JSON should not break ops UI; fall back to defaults.
		return defaultCfg, nil
	}
	normalizeOpsEmailNotificationConfig(cfg)
	return cfg, nil
}
// UpdateEmailNotificationConfig merges the request into the stored config,
// validates and normalizes the result, then persists it as JSON.
//
// Merge semantics: a nil req.Alert / req.Report section leaves that section
// untouched. Within a provided section, Recipients is preserved when nil,
// but every scalar field is overwritten from the request (the frontend
// sends the full section shape).
func (s *OpsService) UpdateEmailNotificationConfig(ctx context.Context, req *OpsEmailNotificationConfigUpdateRequest) (*OpsEmailNotificationConfig, error) {
	if s == nil || s.settingRepo == nil {
		return nil, errors.New("setting repository not initialized")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if req == nil {
		return nil, errors.New("invalid request")
	}
	// Start from the currently stored config (or defaults) so a partial
	// update does not clobber the other section.
	cfg, err := s.GetEmailNotificationConfig(ctx)
	if err != nil {
		return nil, err
	}
	if req.Alert != nil {
		cfg.Alert.Enabled = req.Alert.Enabled
		if req.Alert.Recipients != nil {
			cfg.Alert.Recipients = req.Alert.Recipients
		}
		cfg.Alert.MinSeverity = strings.TrimSpace(req.Alert.MinSeverity)
		cfg.Alert.RateLimitPerHour = req.Alert.RateLimitPerHour
		cfg.Alert.BatchingWindowSeconds = req.Alert.BatchingWindowSeconds
		cfg.Alert.IncludeResolvedAlerts = req.Alert.IncludeResolvedAlerts
	}
	if req.Report != nil {
		cfg.Report.Enabled = req.Report.Enabled
		if req.Report.Recipients != nil {
			cfg.Report.Recipients = req.Report.Recipients
		}
		cfg.Report.DailySummaryEnabled = req.Report.DailySummaryEnabled
		cfg.Report.DailySummarySchedule = strings.TrimSpace(req.Report.DailySummarySchedule)
		cfg.Report.WeeklySummaryEnabled = req.Report.WeeklySummaryEnabled
		cfg.Report.WeeklySummarySchedule = strings.TrimSpace(req.Report.WeeklySummarySchedule)
		cfg.Report.ErrorDigestEnabled = req.Report.ErrorDigestEnabled
		cfg.Report.ErrorDigestSchedule = strings.TrimSpace(req.Report.ErrorDigestSchedule)
		cfg.Report.ErrorDigestMinCount = req.Report.ErrorDigestMinCount
		cfg.Report.AccountHealthEnabled = req.Report.AccountHealthEnabled
		cfg.Report.AccountHealthSchedule = strings.TrimSpace(req.Report.AccountHealthSchedule)
		cfg.Report.AccountHealthErrorRateThreshold = req.Report.AccountHealthErrorRateThreshold
	}
	// Validate first so client errors surface; normalization then fills
	// empty schedules with defaults before persisting.
	if err := validateOpsEmailNotificationConfig(cfg); err != nil {
		return nil, err
	}
	normalizeOpsEmailNotificationConfig(cfg)
	raw, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	if err := s.settingRepo.Set(ctx, SettingKeyOpsEmailNotificationConfig, string(raw)); err != nil {
		return nil, err
	}
	return cfg, nil
}
// defaultOpsEmailNotificationConfig returns the built-in email notification
// defaults: alerts enabled with no recipients and no filtering/batching
// values set, all reports disabled with 09:00 cron schedules.
func defaultOpsEmailNotificationConfig() *OpsEmailNotificationConfig {
	return &OpsEmailNotificationConfig{
		Alert: OpsEmailAlertConfig{
			Enabled:               true,
			Recipients:            []string{},
			MinSeverity:           "", // empty passes validation; semantics live in the sender
			RateLimitPerHour:      0,
			BatchingWindowSeconds: 0,
			IncludeResolvedAlerts: false,
		},
		Report: OpsEmailReportConfig{
			Enabled:                         false,
			Recipients:                      []string{},
			DailySummaryEnabled:             false,
			DailySummarySchedule:            "0 9 * * *", // daily at 09:00
			WeeklySummaryEnabled:            false,
			WeeklySummarySchedule:           "0 9 * * 1", // Mondays at 09:00
			ErrorDigestEnabled:              false,
			ErrorDigestSchedule:             "0 9 * * *",
			ErrorDigestMinCount:             10,
			AccountHealthEnabled:            false,
			AccountHealthSchedule:           "0 9 * * *",
			AccountHealthErrorRateThreshold: 10.0,
		},
	}
}
// normalizeOpsEmailNotificationConfig canonicalizes a decoded config in
// place: nil recipient slices become empty slices, severity and schedules
// are trimmed, and empty schedules are restored to their defaults so
// downstream cron logic never sees a blank spec. Nil-safe.
func normalizeOpsEmailNotificationConfig(cfg *OpsEmailNotificationConfig) {
	if cfg == nil {
		return
	}
	if cfg.Alert.Recipients == nil {
		cfg.Alert.Recipients = []string{}
	}
	if cfg.Report.Recipients == nil {
		cfg.Report.Recipients = []string{}
	}
	cfg.Alert.MinSeverity = strings.TrimSpace(cfg.Alert.MinSeverity)
	// Trim each schedule, falling back to its default when empty.
	trimOrDefault := func(field *string, fallback string) {
		*field = strings.TrimSpace(*field)
		if *field == "" {
			*field = fallback
		}
	}
	trimOrDefault(&cfg.Report.DailySummarySchedule, "0 9 * * *")
	trimOrDefault(&cfg.Report.WeeklySummarySchedule, "0 9 * * 1")
	trimOrDefault(&cfg.Report.ErrorDigestSchedule, "0 9 * * *")
	trimOrDefault(&cfg.Report.AccountHealthSchedule, "0 9 * * *")
}
// validateOpsEmailNotificationConfig rejects out-of-range values and unknown
// severities. It returns the first violation found, or nil when valid.
func validateOpsEmailNotificationConfig(cfg *OpsEmailNotificationConfig) error {
	if cfg == nil {
		return errors.New("invalid config")
	}
	if cfg.Alert.RateLimitPerHour < 0 {
		return errors.New("alert.rate_limit_per_hour must be >= 0")
	}
	if cfg.Alert.BatchingWindowSeconds < 0 {
		return errors.New("alert.batching_window_seconds must be >= 0")
	}
	severity := strings.TrimSpace(cfg.Alert.MinSeverity)
	if severity != "" && severity != "critical" && severity != "warning" && severity != "info" {
		return errors.New("alert.min_severity must be one of: critical, warning, info, or empty")
	}
	if cfg.Report.ErrorDigestMinCount < 0 {
		return errors.New("report.error_digest_min_count must be >= 0")
	}
	threshold := cfg.Report.AccountHealthErrorRateThreshold
	if threshold < 0 || threshold > 100 {
		return errors.New("report.account_health_error_rate_threshold must be between 0 and 100")
	}
	return nil
}
// =========================
// Alert runtime settings
// =========================
// defaultOpsAlertRuntimeSettings returns the built-in alert evaluator
// runtime defaults: evaluate every 60s, distributed leader lock enabled with
// the default key/TTL, silencing disabled. The Thresholds field is left at
// its zero value (no thresholds configured).
func defaultOpsAlertRuntimeSettings() *OpsAlertRuntimeSettings {
	return &OpsAlertRuntimeSettings{
		EvaluationIntervalSeconds: 60,
		DistributedLock: OpsDistributedLockSettings{
			Enabled:    true,
			Key:        opsAlertEvaluatorLeaderLockKeyDefault,
			TTLSeconds: int(opsAlertEvaluatorLeaderLockTTLDefault.Seconds()),
		},
		Silencing: OpsAlertSilencingSettings{
			Enabled:            false,
			GlobalUntilRFC3339: "",
			GlobalReason:       "",
			Entries:            []OpsAlertSilenceEntry{},
		},
	}
}
// normalizeOpsDistributedLockSettings trims the lock key and substitutes the
// provided defaults for an empty key or non-positive TTL. Nil-safe.
func normalizeOpsDistributedLockSettings(s *OpsDistributedLockSettings, defaultKey string, defaultTTLSeconds int) {
	if s == nil {
		return
	}
	if key := strings.TrimSpace(s.Key); key != "" {
		s.Key = key
	} else {
		s.Key = defaultKey
	}
	if s.TTLSeconds <= 0 {
		s.TTLSeconds = defaultTTLSeconds
	}
}
// normalizeOpsAlertSilencingSettings trims the global fields and every
// entry's timestamp/reason, and replaces a nil Entries slice with an empty
// one. Nil-safe.
func normalizeOpsAlertSilencingSettings(s *OpsAlertSilencingSettings) {
	if s == nil {
		return
	}
	s.GlobalUntilRFC3339 = strings.TrimSpace(s.GlobalUntilRFC3339)
	s.GlobalReason = strings.TrimSpace(s.GlobalReason)
	if s.Entries == nil {
		s.Entries = []OpsAlertSilenceEntry{}
	}
	for i := range s.Entries {
		entry := &s.Entries[i]
		entry.UntilRFC3339 = strings.TrimSpace(entry.UntilRFC3339)
		entry.Reason = strings.TrimSpace(entry.Reason)
	}
}
// validateOpsDistributedLockSettings requires a non-blank key and a TTL in
// (0, 24h].
func validateOpsDistributedLockSettings(s OpsDistributedLockSettings) error {
	if strings.TrimSpace(s.Key) == "" {
		return errors.New("distributed_lock.key is required")
	}
	const maxTTLSeconds = 86400 // 24 hours
	if s.TTLSeconds <= 0 || s.TTLSeconds > maxTTLSeconds {
		return errors.New("distributed_lock.ttl_seconds must be between 1 and 86400")
	}
	return nil
}
// validateOpsAlertSilencingSettings checks that the optional global "until"
// timestamp parses as RFC3339, and that every entry has a non-blank,
// RFC3339-parseable "until" timestamp.
func validateOpsAlertSilencingSettings(s OpsAlertSilencingSettings) error {
	// The global timestamp is optional; validate only when non-blank.
	if strings.TrimSpace(s.GlobalUntilRFC3339) != "" {
		if _, err := time.Parse(time.RFC3339, s.GlobalUntilRFC3339); err != nil {
			return errors.New("silencing time must be RFC3339")
		}
	}
	for _, entry := range s.Entries {
		if strings.TrimSpace(entry.UntilRFC3339) == "" {
			return errors.New("silencing.entries.until_rfc3339 is required")
		}
		if _, err := time.Parse(time.RFC3339, entry.UntilRFC3339); err != nil {
			return errors.New("silencing.entries.until_rfc3339 must be RFC3339")
		}
	}
	return nil
}
// GetOpsAlertRuntimeSettings loads alert evaluator runtime settings from the
// settings table. A missing key returns (and best-effort persists) the
// defaults; corrupted JSON falls back to defaults; decoded values are
// normalized so the evaluator never sees a zero interval or blank lock key.
func (s *OpsService) GetOpsAlertRuntimeSettings(ctx context.Context) (*OpsAlertRuntimeSettings, error) {
	defaultCfg := defaultOpsAlertRuntimeSettings()
	if s == nil || s.settingRepo == nil {
		return defaultCfg, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsAlertRuntimeSettings)
	if err != nil {
		if errors.Is(err, ErrSettingNotFound) {
			// First read: seed the stored value with defaults (best-effort).
			if b, mErr := json.Marshal(defaultCfg); mErr == nil {
				_ = s.settingRepo.Set(ctx, SettingKeyOpsAlertRuntimeSettings, string(b))
			}
			return defaultCfg, nil
		}
		return nil, err
	}
	cfg := &OpsAlertRuntimeSettings{}
	if err := json.Unmarshal([]byte(raw), cfg); err != nil {
		// Corrupted JSON must not break the evaluator; fall back to defaults.
		return defaultCfg, nil
	}
	if cfg.EvaluationIntervalSeconds <= 0 {
		cfg.EvaluationIntervalSeconds = defaultCfg.EvaluationIntervalSeconds
	}
	normalizeOpsDistributedLockSettings(&cfg.DistributedLock, opsAlertEvaluatorLeaderLockKeyDefault, defaultCfg.DistributedLock.TTLSeconds)
	normalizeOpsAlertSilencingSettings(&cfg.Silencing)
	return cfg, nil
}
// UpdateOpsAlertRuntimeSettings validates, normalizes, and persists the alert
// evaluator runtime settings. The distributed lock and silencing sections
// are only validated when enabled, so a disabled section may carry stale
// values. Returns a fresh copy decoded from the stored JSON.
func (s *OpsService) UpdateOpsAlertRuntimeSettings(ctx context.Context, cfg *OpsAlertRuntimeSettings) (*OpsAlertRuntimeSettings, error) {
	if s == nil || s.settingRepo == nil {
		return nil, errors.New("setting repository not initialized")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if cfg == nil {
		return nil, errors.New("invalid config")
	}
	if cfg.EvaluationIntervalSeconds < 1 || cfg.EvaluationIntervalSeconds > int((24*time.Hour).Seconds()) {
		return nil, errors.New("evaluation_interval_seconds must be between 1 and 86400")
	}
	if cfg.DistributedLock.Enabled {
		if err := validateOpsDistributedLockSettings(cfg.DistributedLock); err != nil {
			return nil, err
		}
	}
	if cfg.Silencing.Enabled {
		if err := validateOpsAlertSilencingSettings(cfg.Silencing); err != nil {
			return nil, err
		}
	}
	// Normalize after validation so blank keys/TTLs are filled with defaults
	// before the config is persisted.
	defaultCfg := defaultOpsAlertRuntimeSettings()
	normalizeOpsDistributedLockSettings(&cfg.DistributedLock, opsAlertEvaluatorLeaderLockKeyDefault, defaultCfg.DistributedLock.TTLSeconds)
	normalizeOpsAlertSilencingSettings(&cfg.Silencing)
	raw, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	if err := s.settingRepo.Set(ctx, SettingKeyOpsAlertRuntimeSettings, string(raw)); err != nil {
		return nil, err
	}
	// Return a fresh copy (avoid callers holding pointers into internal slices that may be mutated).
	updated := &OpsAlertRuntimeSettings{}
	_ = json.Unmarshal(raw, updated)
	return updated, nil
}
// =========================
// Advanced settings
// =========================
// defaultOpsAdvancedSettings returns the built-in advanced ops defaults:
// cleanup and aggregation disabled, 30-day retention across all metric
// tables, a 02:00 daily cleanup cron, and auto-refresh off with a 30s
// interval.
func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
	return &OpsAdvancedSettings{
		DataRetention: OpsDataRetentionSettings{
			CleanupEnabled:             false,
			CleanupSchedule:            "0 2 * * *", // daily at 02:00
			ErrorLogRetentionDays:      30,
			MinuteMetricsRetentionDays: 30,
			HourlyMetricsRetentionDays: 30,
		},
		Aggregation: OpsAggregationSettings{
			AggregationEnabled: false,
		},
		IgnoreCountTokensErrors:   false,
		IgnoreContextCanceled:     true,  // Default to true - client disconnects are not errors
		IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
		AutoRefreshEnabled:        false,
		AutoRefreshIntervalSec:    30,
	}
}
// normalizeOpsAdvancedSettings fills empty/non-positive fields with their
// defaults: 02:00 cleanup cron, 30-day retention, 30s auto-refresh. Nil-safe.
func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
	if cfg == nil {
		return
	}
	retention := &cfg.DataRetention
	retention.CleanupSchedule = strings.TrimSpace(retention.CleanupSchedule)
	if retention.CleanupSchedule == "" {
		retention.CleanupSchedule = "0 2 * * *"
	}
	if retention.ErrorLogRetentionDays <= 0 {
		retention.ErrorLogRetentionDays = 30
	}
	if retention.MinuteMetricsRetentionDays <= 0 {
		retention.MinuteMetricsRetentionDays = 30
	}
	if retention.HourlyMetricsRetentionDays <= 0 {
		retention.HourlyMetricsRetentionDays = 30
	}
	// Auto refresh interval defaults to 30 seconds.
	if cfg.AutoRefreshIntervalSec <= 0 {
		cfg.AutoRefreshIntervalSec = 30
	}
}
// validateOpsAdvancedSettings bounds-checks retention windows (1-365 days)
// and the auto-refresh interval (15-300 seconds).
func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
	if cfg == nil {
		return errors.New("invalid config")
	}
	checkDays := func(days int, field string) error {
		if days < 1 || days > 365 {
			return errors.New(field + " must be between 1 and 365")
		}
		return nil
	}
	if err := checkDays(cfg.DataRetention.ErrorLogRetentionDays, "error_log_retention_days"); err != nil {
		return err
	}
	if err := checkDays(cfg.DataRetention.MinuteMetricsRetentionDays, "minute_metrics_retention_days"); err != nil {
		return err
	}
	if err := checkDays(cfg.DataRetention.HourlyMetricsRetentionDays, "hourly_metrics_retention_days"); err != nil {
		return err
	}
	if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
		return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
	}
	return nil
}
// GetOpsAdvancedSettings loads advanced ops settings (data retention,
// aggregation, error-ignore flags, auto-refresh) from the settings table.
// A missing key returns (and best-effort persists) the defaults; corrupted
// JSON falls back to defaults; decoded values are normalized.
func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSettings, error) {
	defaultCfg := defaultOpsAdvancedSettings()
	if s == nil || s.settingRepo == nil {
		return defaultCfg, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsAdvancedSettings)
	if err != nil {
		if errors.Is(err, ErrSettingNotFound) {
			// First read: seed the stored value with defaults (best-effort).
			if b, mErr := json.Marshal(defaultCfg); mErr == nil {
				_ = s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(b))
			}
			return defaultCfg, nil
		}
		return nil, err
	}
	cfg := &OpsAdvancedSettings{}
	if err := json.Unmarshal([]byte(raw), cfg); err != nil {
		// Corrupted JSON should not break the ops UI; fall back to defaults.
		return defaultCfg, nil
	}
	normalizeOpsAdvancedSettings(cfg)
	return cfg, nil
}
// UpdateOpsAdvancedSettings validates, normalizes, and persists advanced ops
// settings, returning a fresh copy decoded from the stored JSON so callers
// do not hold pointers into the input struct.
func (s *OpsService) UpdateOpsAdvancedSettings(ctx context.Context, cfg *OpsAdvancedSettings) (*OpsAdvancedSettings, error) {
	if s == nil || s.settingRepo == nil {
		return nil, errors.New("setting repository not initialized")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if cfg == nil {
		return nil, errors.New("invalid config")
	}
	if err := validateOpsAdvancedSettings(cfg); err != nil {
		return nil, err
	}
	normalizeOpsAdvancedSettings(cfg)
	raw, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	if err := s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(raw)); err != nil {
		return nil, err
	}
	updated := &OpsAdvancedSettings{}
	_ = json.Unmarshal(raw, updated)
	return updated, nil
}
// =========================
// Metric thresholds
// =========================
// SettingKeyOpsMetricThresholds is the settings-table key storing dashboard
// metric thresholds as a JSON blob.
const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
// defaultOpsMetricThresholds returns the default dashboard thresholds:
// SLA min 99.5%, TTFT p99 max 500ms, request/upstream error rate max 5%.
func defaultOpsMetricThresholds() *OpsMetricThresholds {
	slaMin := 99.5
	ttftMax := 500.0
	reqErrMax := 5.0
	upstreamErrMax := 5.0
	return &OpsMetricThresholds{
		SLAPercentMin:               &slaMin,
		TTFTp99MsMax:                &ttftMax,
		RequestErrorRatePercentMax:  &reqErrMax,
		UpstreamErrorRatePercentMax: &upstreamErrMax,
	}
}
// GetMetricThresholds loads dashboard metric thresholds from the settings
// table, seeding defaults on first read and falling back to defaults on
// decode failure. Unlike the other getters, no normalization pass is
// applied — stored values are returned as-is after a successful decode.
func (s *OpsService) GetMetricThresholds(ctx context.Context) (*OpsMetricThresholds, error) {
	defaultCfg := defaultOpsMetricThresholds()
	if s == nil || s.settingRepo == nil {
		return defaultCfg, nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMetricThresholds)
	if err != nil {
		if errors.Is(err, ErrSettingNotFound) {
			// First read: seed the stored value with defaults (best-effort).
			if b, mErr := json.Marshal(defaultCfg); mErr == nil {
				_ = s.settingRepo.Set(ctx, SettingKeyOpsMetricThresholds, string(b))
			}
			return defaultCfg, nil
		}
		return nil, err
	}
	cfg := &OpsMetricThresholds{}
	if err := json.Unmarshal([]byte(raw), cfg); err != nil {
		return defaultCfg, nil
	}
	return cfg, nil
}
// UpdateMetricThresholds validates and persists dashboard metric thresholds.
// Each threshold is a pointer: nil means "unset" and skips validation, so a
// partial config is acceptable. Returns a fresh copy decoded from the
// stored JSON.
func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricThresholds) (*OpsMetricThresholds, error) {
	if s == nil || s.settingRepo == nil {
		return nil, errors.New("setting repository not initialized")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if cfg == nil {
		return nil, errors.New("invalid config")
	}
	// Validate thresholds
	if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
		return nil, errors.New("sla_percent_min must be between 0 and 100")
	}
	if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
		return nil, errors.New("ttft_p99_ms_max must be >= 0")
	}
	if cfg.RequestErrorRatePercentMax != nil && (*cfg.RequestErrorRatePercentMax < 0 || *cfg.RequestErrorRatePercentMax > 100) {
		return nil, errors.New("request_error_rate_percent_max must be between 0 and 100")
	}
	if cfg.UpstreamErrorRatePercentMax != nil && (*cfg.UpstreamErrorRatePercentMax < 0 || *cfg.UpstreamErrorRatePercentMax > 100) {
		return nil, errors.New("upstream_error_rate_percent_max must be between 0 and 100")
	}
	raw, err := json.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	if err := s.settingRepo.Set(ctx, SettingKeyOpsMetricThresholds, string(raw)); err != nil {
		return nil, err
	}
	updated := &OpsMetricThresholds{}
	_ = json.Unmarshal(raw, updated)
	return updated, nil
}

View File

@@ -0,0 +1,100 @@
package service
// Ops settings models stored in DB `settings` table (JSON blobs).
// OpsEmailNotificationConfig is the full email notification config stored as
// one JSON blob under SettingKeyOpsEmailNotificationConfig.
type OpsEmailNotificationConfig struct {
	Alert  OpsEmailAlertConfig  `json:"alert"`
	Report OpsEmailReportConfig `json:"report"`
}

// OpsEmailAlertConfig controls alert email delivery.
type OpsEmailAlertConfig struct {
	Enabled               bool     `json:"enabled"`
	Recipients            []string `json:"recipients"`
	MinSeverity           string   `json:"min_severity"` // one of critical/warning/info, or empty
	RateLimitPerHour      int      `json:"rate_limit_per_hour"`
	BatchingWindowSeconds int      `json:"batching_window_seconds"`
	IncludeResolvedAlerts bool     `json:"include_resolved_alerts"`
}

// OpsEmailReportConfig controls scheduled report emails. Schedules are cron
// expressions; empty values are filled with defaults during normalization.
type OpsEmailReportConfig struct {
	Enabled                         bool     `json:"enabled"`
	Recipients                      []string `json:"recipients"`
	DailySummaryEnabled             bool     `json:"daily_summary_enabled"`
	DailySummarySchedule            string   `json:"daily_summary_schedule"`
	WeeklySummaryEnabled            bool     `json:"weekly_summary_enabled"`
	WeeklySummarySchedule           string   `json:"weekly_summary_schedule"`
	ErrorDigestEnabled              bool     `json:"error_digest_enabled"`
	ErrorDigestSchedule             string   `json:"error_digest_schedule"`
	ErrorDigestMinCount             int      `json:"error_digest_min_count"`
	AccountHealthEnabled            bool     `json:"account_health_enabled"`
	AccountHealthSchedule           string   `json:"account_health_schedule"`
	AccountHealthErrorRateThreshold float64  `json:"account_health_error_rate_threshold"`
}

// OpsEmailNotificationConfigUpdateRequest allows partial updates, while the
// frontend can still send the full config shape. A nil section is left
// unchanged by the update.
type OpsEmailNotificationConfigUpdateRequest struct {
	Alert  *OpsEmailAlertConfig  `json:"alert"`
	Report *OpsEmailReportConfig `json:"report"`
}

// OpsDistributedLockSettings configures the evaluator leader lock.
type OpsDistributedLockSettings struct {
	Enabled    bool   `json:"enabled"`
	Key        string `json:"key"`
	TTLSeconds int    `json:"ttl_seconds"`
}

// OpsAlertSilenceEntry silences matching alerts until a given RFC3339 time.
// A nil RuleID / empty Severities means the entry is not scoped by that field.
type OpsAlertSilenceEntry struct {
	RuleID       *int64   `json:"rule_id,omitempty"`
	Severities   []string `json:"severities,omitempty"`
	UntilRFC3339 string   `json:"until_rfc3339"`
	Reason       string   `json:"reason"`
}

// OpsAlertSilencingSettings holds global and per-entry alert silencing.
type OpsAlertSilencingSettings struct {
	Enabled            bool                   `json:"enabled"`
	GlobalUntilRFC3339 string                 `json:"global_until_rfc3339"`
	GlobalReason       string                 `json:"global_reason"`
	Entries            []OpsAlertSilenceEntry `json:"entries,omitempty"`
}

// OpsMetricThresholds defines dashboard alerting thresholds; nil means the
// threshold is unset.
type OpsMetricThresholds struct {
	SLAPercentMin               *float64 `json:"sla_percent_min,omitempty"`                // turns red when SLA drops below this value
	TTFTp99MsMax                *float64 `json:"ttft_p99_ms_max,omitempty"`                // turns red when TTFT P99 exceeds this value
	RequestErrorRatePercentMax  *float64 `json:"request_error_rate_percent_max,omitempty"` // turns red when request error rate exceeds this value
	UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // turns red when upstream error rate exceeds this value
}

// OpsAlertRuntimeSettings bundles the alert evaluator's runtime knobs.
type OpsAlertRuntimeSettings struct {
	EvaluationIntervalSeconds int                        `json:"evaluation_interval_seconds"`
	DistributedLock           OpsDistributedLockSettings `json:"distributed_lock"`
	Silencing                 OpsAlertSilencingSettings  `json:"silencing"`
	Thresholds                OpsMetricThresholds        `json:"thresholds"` // metric threshold configuration
}

// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct {
	DataRetention             OpsDataRetentionSettings `json:"data_retention"`
	Aggregation               OpsAggregationSettings   `json:"aggregation"`
	IgnoreCountTokensErrors   bool                     `json:"ignore_count_tokens_errors"`
	IgnoreContextCanceled     bool                     `json:"ignore_context_canceled"`
	IgnoreNoAvailableAccounts bool                     `json:"ignore_no_available_accounts"`
	AutoRefreshEnabled        bool                     `json:"auto_refresh_enabled"`
	AutoRefreshIntervalSec    int                      `json:"auto_refresh_interval_seconds"`
}

// OpsDataRetentionSettings controls scheduled cleanup of metrics/error logs.
type OpsDataRetentionSettings struct {
	CleanupEnabled             bool   `json:"cleanup_enabled"`
	CleanupSchedule            string `json:"cleanup_schedule"` // cron expression
	ErrorLogRetentionDays      int    `json:"error_log_retention_days"`
	MinuteMetricsRetentionDays int    `json:"minute_metrics_retention_days"`
	HourlyMetricsRetentionDays int    `json:"hourly_metrics_retention_days"`
}

// OpsAggregationSettings toggles metric aggregation.
type OpsAggregationSettings struct {
	AggregationEnabled bool `json:"aggregation_enabled"`
}

View File

@@ -0,0 +1,65 @@
package service
import "time"
// OpsThroughputTrendPoint is one time bucket of request/token throughput.
type OpsThroughputTrendPoint struct {
	BucketStart   time.Time `json:"bucket_start"`
	RequestCount  int64     `json:"request_count"`
	TokenConsumed int64     `json:"token_consumed"`
	QPS           float64   `json:"qps"`
	TPS           float64   `json:"tps"`
}

// OpsThroughputPlatformBreakdownItem is per-platform throughput totals.
type OpsThroughputPlatformBreakdownItem struct {
	Platform      string `json:"platform"`
	RequestCount  int64  `json:"request_count"`
	TokenConsumed int64  `json:"token_consumed"`
}

// OpsThroughputGroupBreakdownItem is per-group throughput totals.
type OpsThroughputGroupBreakdownItem struct {
	GroupID       int64  `json:"group_id"`
	GroupName     string `json:"group_name"`
	RequestCount  int64  `json:"request_count"`
	TokenConsumed int64  `json:"token_consumed"`
}

// OpsThroughputTrendResponse carries bucketed throughput points plus
// optional breakdowns for dashboard drilldown.
type OpsThroughputTrendResponse struct {
	Bucket string                     `json:"bucket"`
	Points []*OpsThroughputTrendPoint `json:"points"`
	// Optional drilldown helpers:
	// - When no platform/group is selected: returns totals by platform.
	// - When platform is selected but group is not: returns top groups in that platform.
	ByPlatform []*OpsThroughputPlatformBreakdownItem `json:"by_platform,omitempty"`
	TopGroups  []*OpsThroughputGroupBreakdownItem    `json:"top_groups,omitempty"`
}

// OpsErrorTrendPoint is one time bucket of error counts, split by category
// (SLA-relevant, business-limited, upstream by status class).
type OpsErrorTrendPoint struct {
	BucketStart                  time.Time `json:"bucket_start"`
	ErrorCountTotal              int64     `json:"error_count_total"`
	BusinessLimitedCount         int64     `json:"business_limited_count"`
	ErrorCountSLA                int64     `json:"error_count_sla"`
	UpstreamErrorCountExcl429529 int64     `json:"upstream_error_count_excl_429_529"`
	Upstream429Count             int64     `json:"upstream_429_count"`
	Upstream529Count             int64     `json:"upstream_529_count"`
}

// OpsErrorTrendResponse carries bucketed error-trend points.
type OpsErrorTrendResponse struct {
	Bucket string               `json:"bucket"`
	Points []*OpsErrorTrendPoint `json:"points"`
}

// OpsErrorDistributionItem aggregates errors for one HTTP status code.
type OpsErrorDistributionItem struct {
	StatusCode      int   `json:"status_code"`
	Total           int64 `json:"total"`
	SLA             int64 `json:"sla"`
	BusinessLimited int64 `json:"business_limited"`
}

// OpsErrorDistributionResponse is the error distribution over a window.
type OpsErrorDistributionResponse struct {
	Total int64                       `json:"total"`
	Items []*OpsErrorDistributionItem `json:"items"`
}

View File

@@ -0,0 +1,26 @@
package service
import (
"context"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetThroughputTrend returns bucketed request/token throughput for the given
// filter, delegating aggregation to the ops repository.
//
// Preconditions enforced here: monitoring enabled, ops repo wired, and a
// filter carrying a valid start<=end time range; violations map to typed
// infra errors.
func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	if filter == nil {
		return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
	}
	if filter.StartTime.After(filter.EndTime) {
		return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
	}
	return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
}

View File

@@ -0,0 +1,131 @@
package service
import (
"encoding/json"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// Gin context keys used by Ops error logger for capturing upstream error details.
// These keys are set by gateway services and consumed by handler/ops_error_logger.go.
const (
	// OpsUpstreamStatusCodeKey holds the last upstream HTTP status code (int).
	OpsUpstreamStatusCodeKey = "ops_upstream_status_code"
	// OpsUpstreamErrorMessageKey holds the trimmed upstream error message.
	OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
	// OpsUpstreamErrorDetailKey holds the trimmed upstream error detail.
	OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
	// OpsUpstreamErrorsKey holds the accumulated []*OpsUpstreamErrorEvent.
	OpsUpstreamErrorsKey = "ops_upstream_errors"
	// Best-effort capture of the current upstream request body so ops can
	// retry the specific upstream attempt (not just the client request).
	// This value is sanitized+trimmed before being persisted.
	OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
)
// setOpsUpstreamError stores upstream failure details on the gin context for
// the ops error logger. Zero status codes and blank (after trim) messages or
// details are skipped rather than stored. Nil-context safe.
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
	if c == nil {
		return
	}
	if upstreamStatusCode > 0 {
		c.Set(OpsUpstreamStatusCodeKey, upstreamStatusCode)
	}
	message := strings.TrimSpace(upstreamMessage)
	if message != "" {
		c.Set(OpsUpstreamErrorMessageKey, message)
	}
	detail := strings.TrimSpace(upstreamDetail)
	if detail != "" {
		c.Set(OpsUpstreamErrorDetailKey, detail)
	}
}
// OpsUpstreamErrorEvent describes one upstream error attempt during a single gateway request.
// It is stored in ops_error_logs.upstream_errors as a JSON array.
type OpsUpstreamErrorEvent struct {
	// AtUnixMs is the event timestamp in Unix milliseconds; filled with the
	// current time by appendOpsUpstreamError when zero.
	AtUnixMs int64 `json:"at_unix_ms,omitempty"`
	// Context
	Platform    string `json:"platform,omitempty"`
	AccountID   int64  `json:"account_id,omitempty"`
	AccountName string `json:"account_name,omitempty"`
	// Outcome
	UpstreamStatusCode int    `json:"upstream_status_code,omitempty"`
	UpstreamRequestID  string `json:"upstream_request_id,omitempty"`
	// Best-effort upstream request capture (sanitized+trimmed).
	// Required for retrying a specific upstream attempt.
	UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
	// Best-effort upstream response capture (sanitized+trimmed).
	UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
	// Kind: http_error | request_error | retry_exhausted | failover
	Kind    string `json:"kind,omitempty"`
	Message string `json:"message,omitempty"`
	Detail  string `json:"detail,omitempty"`
}
// appendOpsUpstreamError normalizes ev and appends it to the per-request
// upstream error list kept on the gin context under OpsUpstreamErrorsKey.
// A zero timestamp is filled with the current time, all string fields are
// trimmed, the message is sanitized, and a missing request body is taken
// from OpsUpstreamRequestBodyKey when the gateway captured one.
func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
	if c == nil {
		return
	}
	if ev.AtUnixMs <= 0 {
		ev.AtUnixMs = time.Now().UnixMilli()
	}
	for _, field := range []*string{
		&ev.Platform,
		&ev.UpstreamRequestID,
		&ev.UpstreamRequestBody,
		&ev.UpstreamResponseBody,
		&ev.Kind,
		&ev.Message,
		&ev.Detail,
	} {
		*field = strings.TrimSpace(*field)
	}
	if ev.Message != "" {
		ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
	}
	// Fall back to the request body the gateway stashed on the context so
	// ops can retry this specific upstream attempt.
	if ev.UpstreamRequestBody == "" {
		if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
			if body, ok := v.(string); ok {
				ev.UpstreamRequestBody = strings.TrimSpace(body)
			}
		}
	}
	var events []*OpsUpstreamErrorEvent
	if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
		if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
			events = arr
		}
	}
	stored := ev
	c.Set(OpsUpstreamErrorsKey, append(events, &stored))
}
// marshalOpsUpstreamErrors encodes the event list as a JSON array for
// persistence. It returns nil (store NULL) for an empty list or a failed
// encode, so the column only ever holds valid JSON.
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
	if len(events) == 0 {
		return nil
	}
	raw, err := json.Marshal(events)
	if err != nil || len(raw) == 0 {
		return nil
	}
	encoded := string(raw)
	return &encoded
}
// ParseOpsUpstreamErrors decodes the stored JSON array of upstream error
// events. A blank input yields an empty (non-nil) slice; malformed JSON
// returns the decode error.
func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return []*OpsUpstreamErrorEvent{}, nil
	}
	var events []*OpsUpstreamErrorEvent
	if err := json.Unmarshal([]byte(trimmed), &events); err != nil {
		return nil, err
	}
	return events, nil
}

View File

@@ -0,0 +1,24 @@
package service
import (
"context"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// GetWindowStats returns lightweight request/token counts for the provided window.
// It is intended for realtime sampling (e.g. WebSocket QPS push) without computing percentiles/peaks.
// Requires monitoring to be enabled and the ops repository to be wired.
func (s *OpsService) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
	if err := s.RequireMonitoringEnabled(ctx); err != nil {
		return nil, err
	}
	if s.opsRepo == nil {
		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
	}
	// NOTE(review): unlike GetThroughputTrend, the window is not validated
	// here — zero or inverted times pass straight through to the repository.
	filter := &OpsDashboardFilter{
		StartTime: startTime,
		EndTime:   endTime,
	}
	return s.opsRepo.GetWindowStats(ctx, filter)
}

View File

@@ -0,0 +1,73 @@
package service
import (
"time"
)
// PromoCode is a registration promo code that credits a balance bonus on
// redemption.
type PromoCode struct {
	ID          int64
	Code        string
	BonusAmount float64 // balance credited on redemption
	MaxUses     int     // 0 means unlimited (see CanUse)
	UsedCount   int
	Status      string
	ExpiresAt   *time.Time // nil means the code never expires
	Notes       string
	CreatedAt   time.Time
	UpdatedAt   time.Time
	// Associations
	UsageRecords []PromoCodeUsage
}
// PromoCodeUsage records one redemption of a promo code by a user.
type PromoCodeUsage struct {
	ID          int64
	PromoCodeID int64
	UserID      int64
	BonusAmount float64 // amount credited at redemption time
	UsedAt      time.Time
	// Associations
	PromoCode *PromoCode
	User      *User
}
// CanUse reports whether the promo code can currently be redeemed: it must
// be active, not past its expiry, and under its usage cap (MaxUses of 0
// means unlimited).
func (p *PromoCode) CanUse() bool {
	switch {
	case p.Status != PromoCodeStatusActive:
		return false
	case p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt):
		return false
	case p.MaxUses > 0 && p.UsedCount >= p.MaxUses:
		return false
	default:
		return true
	}
}
// IsExpired reports whether the code has an expiry time that is in the past.
// Codes without an expiry never expire.
func (p *PromoCode) IsExpired() bool {
	if p.ExpiresAt == nil {
		return false
	}
	return time.Now().After(*p.ExpiresAt)
}
// CreatePromoCodeInput carries the fields for creating a promo code.
// An empty Code asks the service to auto-generate one.
type CreatePromoCodeInput struct {
	Code        string
	BonusAmount float64
	MaxUses     int
	ExpiresAt   *time.Time
	Notes       string
}

// UpdatePromoCodeInput carries optional fields for updating a promo code.
// NOTE(review): nil pointers presumably mean "leave unchanged" — confirm
// against the update service implementation.
type UpdatePromoCodeInput struct {
	Code        *string
	BonusAmount *float64
	MaxUses     *int
	Status      *string
	ExpiresAt   *time.Time
	Notes       *string
}

View File

@@ -0,0 +1,30 @@
package service
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// PromoCodeRepository is the persistence interface for promo codes and
// their usage records.
type PromoCodeRepository interface {
	// Basic CRUD.
	Create(ctx context.Context, code *PromoCode) error
	GetByID(ctx context.Context, id int64) (*PromoCode, error)
	GetByCode(ctx context.Context, code string) (*PromoCode, error)
	// GetByCodeForUpdate fetches with a row lock for concurrency control
	// during redemption.
	GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error)
	Update(ctx context.Context, code *PromoCode) error
	Delete(ctx context.Context, id int64) error
	// List queries.
	List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
	ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
	// Usage records.
	CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
	GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
	ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
	// Counter operations.
	IncrementUsedCount(ctx context.Context, id int64) error
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// Typed promo-code errors surfaced to API layers as infra errors.
var (
	ErrPromoCodeNotFound    = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
	ErrPromoCodeExpired     = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
	ErrPromoCodeDisabled    = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
	ErrPromoCodeMaxUsed     = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
	ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
	ErrPromoCodeInvalid     = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
)
// PromoService implements promo code validation and redemption.
type PromoService struct {
	promoRepo           PromoCodeRepository
	userRepo            UserRepository
	billingCacheService *BillingCacheService
	entClient           *dbent.Client // used to open redemption transactions
	authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewPromoService wires a PromoService with its dependencies.
func NewPromoService(
	promoRepo PromoCodeRepository,
	userRepo UserRepository,
	billingCacheService *BillingCacheService,
	entClient *dbent.Client,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService {
	return &PromoService{
		promoRepo:            promoRepo,
		userRepo:             userRepo,
		billingCacheService:  billingCacheService,
		entClient:            entClient,
		authCacheInvalidator: authCacheInvalidator,
	}
}
// ValidatePromoCode checks a promo code before registration. A blank (or
// whitespace-only) code is not an error and returns (nil, nil).
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
	trimmed := strings.TrimSpace(code)
	if trimmed == "" {
		// Blank codes are allowed; nothing to validate.
		return nil, nil
	}
	promo, err := s.promoRepo.GetByCode(ctx, trimmed)
	if err != nil {
		// Preserve the repository's error type; do not collapse to NotFound.
		return nil, err
	}
	if err := s.validatePromoCodeStatus(promo); err != nil {
		return nil, err
	}
	return promo, nil
}
// validatePromoCodeStatus maps an unusable promo code to its specific typed
// error (expired, disabled, exhausted), defaulting to ErrPromoCodeInvalid.
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
	if promoCode.CanUse() {
		return nil
	}
	switch {
	case promoCode.IsExpired():
		return ErrPromoCodeExpired
	case promoCode.Status == PromoCodeStatusDisabled:
		return ErrPromoCodeDisabled
	case promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses:
		return ErrPromoCodeMaxUsed
	default:
		return ErrPromoCodeInvalid
	}
}
// ApplyPromoCode redeems a promo code for a user (called after successful
// registration). A blank code is a no-op. The whole redemption runs inside
// one transaction with a row lock on the promo code so concurrent
// redemptions cannot exceed MaxUses or double-credit a user.
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
	code = strings.TrimSpace(code)
	if code == "" {
		return nil
	}
	// Open the transaction.
	tx, err := s.entClient.Tx(ctx)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded.
	defer func() { _ = tx.Rollback() }()
	txCtx := dbent.NewTxContext(ctx, tx)
	// Fetch and lock the promo code row (FOR UPDATE) inside the transaction.
	promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
	if err != nil {
		return err
	}
	// Re-validate the code state under the lock.
	if err := s.validatePromoCodeStatus(promoCode); err != nil {
		return err
	}
	// Reject a second redemption of the same code by the same user.
	existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
	if err != nil {
		return fmt.Errorf("check existing usage: %w", err)
	}
	if existing != nil {
		return ErrPromoCodeAlreadyUsed
	}
	// Credit the bonus to the user's balance.
	if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
		return fmt.Errorf("update user balance: %w", err)
	}
	// Record the redemption.
	usage := &PromoCodeUsage{
		PromoCodeID: promoCode.ID,
		UserID:      userID,
		BonusAmount: promoCode.BonusAmount,
		UsedAt:      time.Now(),
	}
	if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
		return fmt.Errorf("create usage record: %w", err)
	}
	// Bump the code's usage counter.
	if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
		return fmt.Errorf("increment used count: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
	// Invalidate the balance cache asynchronously (best-effort, 5s budget).
	if s.billingCacheService != nil {
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
		}()
	}
	return nil
}
// invalidatePromoCaches drops the user's API-key auth cache after a promo
// bonus changed their balance. It does nothing for a zero bonus or when no
// invalidator was configured.
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
	if s.authCacheInvalidator != nil && bonusAmount != 0 {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
}
// GenerateRandomCode produces a random 16-character uppercase hex promo
// code backed by a cryptographically secure source of randomness.
func (s *PromoService) GenerateRandomCode() (string, error) {
	buf := make([]byte, 8)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("generate random bytes: %w", err)
	}
	encoded := hex.EncodeToString(buf)
	return strings.ToUpper(encoded), nil
}
// Create persists a new promo code. A blank input code is replaced by a
// randomly generated one; stored codes are always uppercase.
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
	code := strings.TrimSpace(input.Code)
	if code == "" {
		// No code supplied: generate one automatically.
		generated, err := s.GenerateRandomCode()
		if err != nil {
			return nil, err
		}
		code = generated
	}
	promoCode := &PromoCode{
		Code:        strings.ToUpper(code),
		BonusAmount: input.BonusAmount,
		MaxUses:     input.MaxUses,
		UsedCount:   0,
		Status:      PromoCodeStatusActive,
		ExpiresAt:   input.ExpiresAt,
		Notes:       input.Notes,
	}
	if err := s.promoRepo.Create(ctx, promoCode); err != nil {
		return nil, fmt.Errorf("create promo code: %w", err)
	}
	return promoCode, nil
}
// GetByID loads a promo code by its primary key, delegating directly to
// the repository and passing its error through unchanged.
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
	return s.promoRepo.GetByID(ctx, id)
}
// Update applies a partial update to an existing promo code. Only non-nil
// input fields overwrite stored values; the code itself is normalized to
// trimmed uppercase.
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
	existing, err := s.promoRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if input.Code != nil {
		existing.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
	}
	if input.BonusAmount != nil {
		existing.BonusAmount = *input.BonusAmount
	}
	if input.MaxUses != nil {
		existing.MaxUses = *input.MaxUses
	}
	if input.Status != nil {
		existing.Status = *input.Status
	}
	if input.ExpiresAt != nil {
		existing.ExpiresAt = input.ExpiresAt
	}
	if input.Notes != nil {
		existing.Notes = *input.Notes
	}
	if err := s.promoRepo.Update(ctx, existing); err != nil {
		return nil, fmt.Errorf("update promo code: %w", err)
	}
	return existing, nil
}
// Delete removes a promo code by ID, wrapping repository failures with
// context for the caller.
func (s *PromoService) Delete(ctx context.Context, id int64) error {
	err := s.promoRepo.Delete(ctx, id)
	if err != nil {
		return fmt.Errorf("delete promo code: %w", err)
	}
	return nil
}
// List returns a page of promo codes, optionally filtered by status and a
// free-text search term, delegating pagination to the repository.
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
	return s.promoRepo.ListWithFilters(ctx, params, status, search)
}
// ListUsages returns a page of usage records for the given promo code.
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
	return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
}

View File

@@ -0,0 +1,275 @@
You are a coding agent running in the Codex CLI, a terminal-based coding assistant. Codex CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful.
Your capabilities:
- Receive user prompts and other context provided by the harness, such as files in the workspace.
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
- Emit function calls to run terminal commands and apply patches. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
# How you work
## Personality
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
# AGENTS.md spec
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
- Instructions in AGENTS.md files:
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
## Responsiveness
### Preamble messages
Before making tool calls, send a brief preamble to the user explaining what youre about to do. When sending preamble messages, follow these principles and examples:
- **Logically group related actions**: if youre about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates).
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with whats been done so far and create a sense of momentum and clarity for the user to understand your next actions.
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless its part of a larger grouped action.
**Examples:**
- “Ive explored the repo; now checking the API route definitions.”
- “Next, Ill patch the config and update the related tests.”
- “Im about to scaffold the CLI commands and helper functions.”
- “Ok cool, so Ive wrapped my head around the repo. Now digging into the API routes.”
- “Configs looking tidy. Next up is patching helpers to keep things in sync.”
- “Finished poking at the DB gateway. I will now chase down error handling.”
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
- “Spotted a clever caching util; now hunting where it gets used.”
## Planning
You have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. It may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
Use a plan when:
- The task is non-trivial and will require multiple actions over a long time horizon.
- There are logical phases or dependencies where sequencing matters.
- The work has ambiguity that benefits from outlining high-level goals.
- You want intermediate checkpoints for feedback and validation.
- When the user asked you to do more than one thing in a single prompt
- The user has asked you to use the plan tool (aka "TODOs")
- You generate additional steps while working, and plan to do them before yielding to the user
### Examples
**High-quality plans**
Example 1:
1. Add CLI entry with file args
2. Parse Markdown via CommonMark library
3. Apply semantic HTML template
4. Handle code blocks, images, links
5. Add error handling for invalid files
Example 2:
1. Define CSS variables for colors
2. Add toggle with localStorage state
3. Refactor components to use variables
4. Verify all views for readability
5. Add smooth theme-change transition
Example 3:
1. Set up Node.js + WebSocket server
2. Add join/leave broadcast events
3. Implement messaging with timestamps
4. Add usernames + mention highlighting
5. Persist messages in lightweight DB
6. Add typing indicators + unread count
**Low-quality plans**
Example 1:
1. Create CLI tool
2. Add Markdown parser
3. Convert to HTML
Example 2:
1. Add dark mode toggle
2. Save preference
3. Make styles look good
Example 3:
1. Create single-file HTML game
2. Run quick sanity check
3. Summarize usage instructions
If you need to write a plan, only write high quality plans, not low quality ones.
## Task execution
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
You MUST adhere to the following criteria when solving queries:
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
- Analyzing code for vulnerabilities is allowed.
- Showing user code and tool call details is allowed.
- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]}
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
- Avoid unneeded complexity in your solution.
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
- Update documentation as necessary.
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
- NEVER add copyright or license headers unless specifically requested.
- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
- Do not `git commit` your changes or create new git branches unless explicitly requested.
- Do not add inline comments within code unless explicitly requested.
- Do not use one-letter variable names unless explicitly requested.
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
## Validating your work
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
## Ambition vs. precision
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
## Sharing progress updates
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
## Presenting your work and final message
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the users style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation.
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If theres something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
### Final answer structure and style guidelines
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
**Section Headers**
- Use only when they improve clarity — they are not mandatory for every answer.
- Choose descriptive names that fit the content
- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
- Leave no blank line before the first bullet under a header.
- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer.
**Bullets**
- Use `-` followed by a space for every bullet.
- Merge related points when possible; avoid a bullet for every trivial detail.
- Keep bullets to one line unless breaking for clarity is unavoidable.
- Group into short lists (4–6 bullets) ordered by importance.
- Use consistent keyword phrasing and formatting across sections.
**Monospace**
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
- Never mix monospace and bold markers; choose one based on whether its a keyword (`**`) or inline code/path (`` ` ``).
**File References**
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Line/column (1based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
**Structure**
- Place related bullets together; dont mix unrelated concepts in the same section.
- Order sections from general → specific → supporting info.
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
- Match structure to complexity:
- Multi-part or detailed results → use clear headers and grouped bullets.
- Simple results → minimal headers, possibly just a short list or paragraph.
**Tone**
- Keep the voice collaborative and natural, like a coding partner handing off work.
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
- Keep descriptions self-contained; dont refer to “above” or “below”.
- Use parallel structure in lists for consistency.
**Dont**
- Dont use literal words “bold” or “monospace” in the content.
- Dont nest bullets or create deep hierarchies.
- Dont output ANSI escape codes directly — the CLI renderer applies them.
- Dont cram unrelated keywords into a single bullet; split for clarity.
- Dont let keyword lists run long — wrap or reformat for scanability.
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with whats needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
# Tool Guidelines
## Shell commands
When using the shell, you must adhere to the following guidelines:
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
- Do not use python scripts to attempt to output larger chunks of a file.
## `update_plan`
A tool named `update_plan` is available to you. You can use it to keep an uptodate, stepbystep plan for the task.
To create a new plan, call `update_plan` with a short list of 1sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
When steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call.
If all steps are complete, ensure you call `update_plan` to mark all steps as `completed`.

View File

@@ -0,0 +1,122 @@
# Codex Running in OpenCode
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
## CRITICAL: Tool Replacements
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan, read_plan, readPlan
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
## Available OpenCode Tools
**File Operations:**
- `write` - Create new files
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
- `edit` - Modify existing files (REPLACES apply_patch)
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
- `read` - Read file contents
**Search/Discovery:**
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
- `list` - List directories (requires absolute paths)
**Execution:**
- `bash` - Run shell commands
- No workdir parameter; do not include it in tool calls.
- Always include a short description for the command.
- Do not use cd; use absolute paths in commands.
- Quote paths containing spaces with double quotes.
- Chain multiple commands with ';' or '&&'; avoid newlines.
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
- For deletions (rm), verify by listing parent dir with `list`.
**Network:**
- `webfetch` - Fetch web content
- Use fully-formed URLs (http/https; http auto-upgrades to https).
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
- Read-only; short cache window.
**Task Management:**
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
- `todoread` - Read current plan
## Substitution Rules
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
**Path Usage:** Use per-tool conventions to avoid conflicts:
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
## Verification Checklist
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I following each tool's path requirements?
If ANY answer is NO → STOP and correct before proceeding.
## OpenCode Working Style
**Communication:**
- Send brief preambles (8-12 words) before tool calls, building on prior context
- Provide progress updates during longer tasks
**Execution:**
- Keep working autonomously until query is fully resolved before yielding
- Don't return to user with partial solutions
**Code Approach:**
- New projects: Be ambitious and creative
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
**Testing:**
- If tests exist: Start specific to your changes, then broader validation
## Advanced Tools
**Task Tool (Sub-Agents):**
- Use the Task tool (functions.task) to launch sub-agents
- Check the Task tool description for current agent types and their capabilities
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
- The agent list is dynamically generated - refer to tool schema for available agents
**Parallelization:**
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
- Reserve sequential calls for ordered or data-dependent steps.
**MCP Tools:**
- Model Context Protocol servers provide additional capabilities
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
- Check your available tools for MCP integrations
- Use when the tool's functionality matches your task needs
## What Remains from Codex
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
## Approvals & Safety
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).

View File

@@ -0,0 +1,63 @@
<user_instructions priority="0">
<environment_override priority="0">
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
</environment_override>
<tool_replacements priority="0">
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan
- ALWAYS use: todowrite for ALL task/plan operations
- Use todoread to read current plan
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
</tool_replacements>
<available_tools priority="0">
File Operations:
• write - Create new files
• edit - Modify existing files (REPLACES apply_patch)
• patch - Apply diff patches
• read - Read file contents
Search/Discovery:
• grep - Search file contents
• glob - Find files by pattern
• list - List directories (use relative paths)
Execution:
• bash - Run shell commands
Network:
• webfetch - Fetch web content
Task Management:
• todowrite - Manage tasks/plans (REPLACES update_plan)
• todoread - Read current plan
</available_tools>
<substitution_rules priority="0">
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
absolute paths → relative paths
</substitution_rules>
<verification_checklist priority="0">
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I using relative paths?
If ANY answer is NO → STOP and correct before proceeding.
</verification_checklist>
</user_instructions>

View File

@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
AccountCount int64
LatencyMs *int64
LatencyStatus string
LatencyMessage string
IPAddress string
Country string
CountryCode string
Region string
City string
}
// ProxyAccountSummary is a lightweight projection of an account bound to a
// proxy, returned when listing the accounts that use a given proxy
// (see ProxyRepository.ListAccountSummariesByProxyID).
type ProxyAccountSummary struct {
	ID       int64
	Name     string
	Platform string
	Type     string
	Notes    *string // optional free-form notes; nil when unset
}

View File

@@ -0,0 +1,23 @@
package service
import (
"context"
"time"
)
// ProxyLatencyInfo is the cached result of a proxy latency check, together
// with the IP/geo metadata resolved for the proxy's egress address.
type ProxyLatencyInfo struct {
	Success     bool      `json:"success"`                // whether the latency check succeeded
	LatencyMs   *int64    `json:"latency_ms,omitempty"`   // measured latency in milliseconds; nil when unavailable
	Message     string    `json:"message,omitempty"`      // human-readable detail, typically set on failure
	IPAddress   string    `json:"ip_address,omitempty"`   // egress IP observed through the proxy
	Country     string    `json:"country,omitempty"`      // geo lookup of IPAddress
	CountryCode string    `json:"country_code,omitempty"` // ISO country code — NOTE(review): format not visible here, confirm
	Region      string    `json:"region,omitempty"`
	City        string    `json:"city,omitempty"`
	UpdatedAt   time.Time `json:"updated_at"` // when this entry was produced
}
// ProxyLatencyCache stores and retrieves per-proxy latency check results.
type ProxyLatencyCache interface {
	// GetProxyLatencies returns cached latency info keyed by proxy ID for the
	// requested IDs; IDs without a cached entry are presumably absent from the
	// map — confirm with the implementation.
	GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
	// SetProxyLatency stores the latency info for a single proxy.
	SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
}

View File

@@ -10,6 +10,7 @@ import (
var (
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
)
type ProxyRepository interface {
@@ -26,6 +27,7 @@ type ProxyRepository interface {
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
}
// CreateProxyRequest 创建代理请求

View File

@@ -3,7 +3,7 @@ package service
import (
"context"
"encoding/json"
"log"
"log/slog"
"net/http"
"strconv"
"strings"
@@ -15,13 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
@@ -44,43 +47,105 @@ func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogReposi
}
}
// SetTimeoutCounterCache 设置超时计数器缓存(可选依赖)
func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache
}
// SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
s.tokenCacheInvalidator = invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
tempMatched := false
if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" {
upstreamMsg = truncateForLog([]byte(upstreamMsg), 512)
}
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth {
// 1. 失效缓存
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
}
}
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 402:
// 支付要求:余额不足或计费问题,停止调度
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
msg := "Payment required (402): insufficient balance or billing issue"
if upstreamMsg != "" {
msg = "Payment required (402): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 429:
s.handle429(ctx, account, headers)
s.handle429(ctx, account, headers, responseBody)
shouldDisable = false
case 529:
s.handle529(ctx, account)
shouldDisable = false
default:
// 其他5xx错误记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
// 自定义错误码启用时:在列表中的错误码都应该停止调度
if customErrorCodesEnabled {
msg := "Custom error code triggered"
if upstreamMsg != "" {
msg = upstreamMsg
}
s.handleCustomErrorCode(ctx, account, statusCode, msg)
shouldDisable = true
} else if statusCode >= 500 {
// 未启用自定义错误码时仅记录5xx错误
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false
}
shouldDisable = false
}
if tempMatched {
@@ -125,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -150,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -172,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -193,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -250,22 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
}
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
}
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -273,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
resetAt := time.Unix(ts, 0)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
return
}
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
return
}
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
@@ -293,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
if account == nil || account.Platform != PlatformAnthropic {
return false
}
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
if msg == "" {
return false
}
return strings.Contains(msg, "sonnet")
}
// handle529 处理529过载错误
@@ -309,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
@@ -336,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
}
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
}
}
}
@@ -356,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
return err
}
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil {
return err
}
return s.accountRepo.ClearModelRateLimits(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
@@ -365,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
return nil
@@ -412,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
}
}
@@ -515,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
}
}
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true
}
@@ -538,3 +652,124 @@ func truncateTempUnschedMessage(body []byte, maxBytes int) string {
}
return strings.TrimSpace(string(body))
}
// HandleStreamTimeout reacts to a stream-data interval timeout on the given
// account. Based on the system-wide stream-timeout settings it counts the
// timeout inside a rolling window and, once the configured threshold is
// reached, either marks the account temporarily unschedulable or puts it into
// the error state. It returns true when the account should stop being
// scheduled.
func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Account, model string) bool {
	if account == nil {
		return false
	}
	// Without the settings service we cannot know the configured action.
	if s.settingService == nil {
		slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
		return false
	}
	settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
	if err != nil {
		slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
		return false
	}
	if !settings.Enabled || settings.Action == StreamTimeoutActionNone {
		return false
	}
	// Count this timeout within the rolling window. If the counter cache is
	// missing or fails, degrade gracefully and treat this as the first hit.
	timeouts := int64(1)
	if s.timeoutCounterCache != nil {
		n, incErr := s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
		if incErr != nil {
			slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", incErr)
		} else {
			timeouts = n
		}
	}
	slog.Info("stream_timeout_count", "account_id", account.ID, "count", timeouts, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
	if timeouts < int64(settings.ThresholdCount) {
		return false
	}
	// Threshold reached: apply the configured action.
	if settings.Action == StreamTimeoutActionTempUnsched {
		return s.triggerStreamTimeoutTempUnsched(ctx, account, settings, model)
	}
	if settings.Action == StreamTimeoutActionError {
		return s.triggerStreamTimeoutError(ctx, account, model)
	}
	return false
}
// triggerStreamTimeoutTempUnsched marks the account temporarily unschedulable
// after repeated stream timeouts, mirrors the state into the temp-unsched
// cache, and resets the timeout counter. Returns true when the account was
// successfully marked (i.e. its scheduling should stop for now).
func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, account *Account, settings *StreamTimeoutSettings, model string) bool {
	now := time.Now()
	until := now.Add(time.Duration(settings.TempUnschedMinutes) * time.Minute)
	state := &TempUnschedState{
		UntilUnix:       until.Unix(),
		TriggeredAtUnix: now.Unix(),
		StatusCode:      0, // timeouts carry no HTTP status code
		MatchedKeyword:  "stream_timeout",
		RuleIndex:       -1, // -1 marks a system-level rule (not a user-configured one)
		ErrorMessage:    "Stream data interval timeout for model: " + model,
	}
	// Persist the full state as JSON; fall back to the bare message if
	// marshalling fails so the reason is never empty.
	reason := ""
	if raw, err := json.Marshal(state); err == nil {
		reason = string(raw)
	}
	if reason == "" {
		reason = state.ErrorMessage
	}
	if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
		slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
		return false
	}
	// Cache write failures are non-fatal: the DB record is authoritative.
	if s.tempUnschedCache != nil {
		if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
			slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
		}
	}
	// Reset the timeout counter so the next window starts from zero.
	if s.timeoutCounterCache != nil {
		if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
			slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
		}
	}
	slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
	return true
}
// triggerStreamTimeoutError puts the account into the error state after
// repeated stream timeouts and resets the timeout counter. Returns true when
// the account was successfully marked errored.
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
	msg := "Stream data interval timeout (repeated failures) for model: " + model
	if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
		slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
		return false
	}
	// Clear the counter so stale history does not penalize a later recovery.
	if cache := s.timeoutCounterCache; cache != nil {
		if err := cache.ResetTimeoutCount(ctx, account.ID); err != nil {
			slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
		}
	}
	slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
	return true
}

View File

@@ -0,0 +1,121 @@
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// rateLimitAccountRepoStub records SetError / SetTempUnschedulable calls on
// top of the shared mockAccountRepoForGemini embedding, so tests can assert
// which state transition the rate-limit service attempted.
type rateLimitAccountRepoStub struct {
	mockAccountRepoForGemini
	setErrorCalls int    // number of SetError invocations
	tempCalls     int    // number of SetTempUnschedulable invocations
	lastErrorMsg  string // message passed to the most recent SetError call
}
// SetError records the call and captures the error message; always succeeds.
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
	r.setErrorCalls++
	r.lastErrorMsg = errorMsg
	return nil
}
// SetTempUnschedulable records the call; always succeeds.
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	r.tempCalls++
	return nil
}
// tokenCacheInvalidatorRecorder captures the accounts passed to
// InvalidateToken and returns a configurable error.
type tokenCacheInvalidatorRecorder struct {
	accounts []*Account // accounts seen by InvalidateToken, in call order
	err      error      // returned by every InvalidateToken call
}
// InvalidateToken records the account and returns the preconfigured error.
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
	r.accounts = append(r.accounts, account)
	return r.err
}
// TestRateLimitService_HandleUpstreamError_OAuth401MarksError verifies that a
// 401 on an OAuth account (gemini and antigravity) marks the account errored
// via SetError, invalidates the token cache exactly once, and does NOT apply
// temp-unschedulable rules — even when a configured rule matches the 401 and
// its "unauthorized" keyword.
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
	tests := []struct {
		name     string
		platform string
	}{
		{name: "gemini", platform: PlatformGemini},
		{name: "antigravity", platform: PlatformAntigravity},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo := &rateLimitAccountRepoStub{}
			invalidator := &tokenCacheInvalidatorRecorder{}
			service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
			service.SetTokenCacheInvalidator(invalidator)
			// Temp-unschedulable rule deliberately matches 401/"unauthorized"
			// to prove the auth-error path wins over the rule.
			account := &Account{
				ID:       100,
				Platform: tt.platform,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"temp_unschedulable_enabled": true,
					"temp_unschedulable_rules": []any{
						map[string]any{
							"error_code":       401,
							"keywords":         []any{"unauthorized"},
							"duration_minutes": 30,
							"description":      "custom rule",
						},
					},
				},
			}
			shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
			require.True(t, shouldDisable)
			require.Equal(t, 1, repo.setErrorCalls)
			require.Equal(t, 0, repo.tempCalls)
			require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
			require.Len(t, invalidator.accounts, 1)
		})
	}
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError verifies
// that a token-cache invalidation failure does not prevent the 401 handling:
// the account is still marked errored and scheduling is still disabled.
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
	service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	service.SetTokenCacheInvalidator(invalidator)
	account := &Account{
		ID:       101,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
	require.True(t, shouldDisable)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Len(t, invalidator.accounts, 1)
}
// TestRateLimitService_HandleUpstreamError_NonOAuth401 verifies that a 401 on
// a non-OAuth (API-key) account still disables the account, but never touches
// the token-cache invalidator (that path is OAuth-only).
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	invalidator := &tokenCacheInvalidatorRecorder{}
	service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	service.SetTokenCacheInvalidator(invalidator)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
	require.True(t, shouldDisable)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Empty(t, invalidator.accounts)
}

View File

@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewRedeemService 创建兑换码服务实例
@@ -84,14 +85,16 @@ func NewRedeemService(
cache RedeemCache,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// invalidateRedeemCaches 失效兑换相关的缓存
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
if s.billingCacheService == nil {
return
}
switch redeemCode.Type {
case RedeemTypeBalance:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
case RedeemTypeConcurrency:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
case RedeemTypeSubscription:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
if redeemCode.GroupID != nil {
groupID := *redeemCode.GroupID
go func() {

View File

@@ -0,0 +1,68 @@
package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"
)
// Scheduler bucket modes. One of these is embedded in every bucket key to
// distinguish how accounts were selected for it (see SchedulerBucket).
const (
	SchedulerModeSingle = "single" // single-platform scheduling
	SchedulerModeMixed  = "mixed"  // mixed-platform scheduling — TODO confirm exact semantics with resolveMode
	SchedulerModeForced = "forced" // caller forced a specific platform — TODO confirm with resolveMode
)
// SchedulerBucket identifies one scheduling partition: a (group, platform,
// mode) triple. Its string form is used as a cache key.
type SchedulerBucket struct {
	GroupID  int64
	Platform string
	Mode     string
}

// String renders the bucket as "<groupID>:<platform>:<mode>"; it is the
// inverse of ParseSchedulerBucket.
func (b SchedulerBucket) String() string {
	return fmt.Sprintf("%d:%s:%s", b.GroupID, b.Platform, b.Mode)
}

// ParseSchedulerBucket parses a "<groupID>:<platform>:<mode>" key produced by
// String. It reports false unless the input has exactly three colon-separated
// segments, the first is a valid base-10 int64, and the last two are non-empty.
func ParseSchedulerBucket(raw string) (SchedulerBucket, bool) {
	idPart, rest, found := strings.Cut(raw, ":")
	if !found {
		return SchedulerBucket{}, false
	}
	platform, mode, found := strings.Cut(rest, ":")
	if !found || strings.Contains(mode, ":") {
		// The key format has exactly three segments; fewer or extra colons
		// are invalid.
		return SchedulerBucket{}, false
	}
	if platform == "" || mode == "" {
		return SchedulerBucket{}, false
	}
	groupID, err := strconv.ParseInt(idPart, 10, 64)
	if err != nil {
		return SchedulerBucket{}, false
	}
	return SchedulerBucket{GroupID: groupID, Platform: platform, Mode: mode}, true
}
// SchedulerCache handles reading and writing bucket scheduling snapshots and
// per-account snapshots.
type SchedulerCache interface {
	// GetSnapshot reads a bucket snapshot and reports whether it was a usable
	// hit (ready + active + complete data).
	GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error)
	// SetSnapshot writes a snapshot and switches the active version to it.
	SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error
	// GetAccount fetches a single account snapshot.
	GetAccount(ctx context.Context, accountID int64) (*Account, error)
	// SetAccount writes a single account snapshot (including its
	// unschedulable state).
	SetAccount(ctx context.Context, account *Account) error
	// DeleteAccount removes a single account snapshot.
	DeleteAccount(ctx context.Context, accountID int64) error
	// UpdateLastUsed batch-updates accounts' last-used timestamps.
	UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
	// TryLockBucket attempts to acquire the per-bucket rebuild lock.
	TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
	// ListBuckets returns the set of registered buckets.
	ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
	// GetOutboxWatermark reads the outbox watermark (last processed event ID).
	GetOutboxWatermark(ctx context.Context) (int64, error)
	// SetOutboxWatermark stores the outbox watermark.
	SetOutboxWatermark(ctx context.Context, id int64) error
}

View File

@@ -0,0 +1,10 @@
package service
// Scheduler outbox event types. Repositories emit these rows; the snapshot
// service consumes them to keep scheduler caches in sync.
const (
	SchedulerOutboxEventAccountChanged       = "account_changed"
	SchedulerOutboxEventAccountGroupsChanged = "account_groups_changed"
	SchedulerOutboxEventAccountBulkChanged   = "account_bulk_changed"
	SchedulerOutboxEventAccountLastUsed      = "account_last_used"
	SchedulerOutboxEventGroupChanged         = "group_changed"
	SchedulerOutboxEventFullRebuild          = "full_rebuild"
)

View File

@@ -0,0 +1,21 @@
package service
import (
"context"
"time"
)
// SchedulerOutboxEvent is one change-notification row read from the scheduler
// outbox.
type SchedulerOutboxEvent struct {
	ID        int64          // monotonically increasing event ID, used as the watermark cursor
	EventType string         // presumably one of the SchedulerOutboxEvent* constants — confirm with producers
	AccountID *int64         // affected account, when the event is account-scoped
	GroupID   *int64         // affected group, when the event is group-scoped
	Payload   map[string]any // event-specific extra data
	CreatedAt time.Time
}
// SchedulerOutboxRepository provides read access to the scheduler outbox.
type SchedulerOutboxRepository interface {
	// ListAfter returns up to limit events with ID greater than afterID —
	// presumably in ascending ID order; confirm with the implementation.
	ListAfter(ctx context.Context, afterID int64, limit int) ([]SchedulerOutboxEvent, error)
	// MaxID returns the largest event ID currently stored.
	MaxID(ctx context.Context) (int64, error)
}

View File

@@ -0,0 +1,786 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strconv"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var (
	// ErrSchedulerCacheNotReady: the scheduler cache is not ready to serve.
	ErrSchedulerCacheNotReady = errors.New("scheduler cache not ready")
	// ErrSchedulerFallbackLimited: the database fallback was rejected by the
	// fallback rate limiter.
	ErrSchedulerFallbackLimited = errors.New("scheduler db fallback limited")
)

// outboxEventTimeout — NOTE(review): its consumer is not visible in this
// chunk; presumably bounds how old an outbox event may be before it is
// treated as stale. Confirm against pollOutbox.
const outboxEventTimeout = 2 * time.Minute
// SchedulerSnapshotService keeps scheduler snapshots in sync with the
// database: it rebuilds bucket snapshots at startup, tails the scheduler
// outbox, and periodically performs full rebuilds. When the cache cannot
// serve a request it falls back to the database, guarded by a rate limiter.
type SchedulerSnapshotService struct {
	cache         SchedulerCache
	outboxRepo    SchedulerOutboxRepository
	accountRepo   AccountRepository
	groupRepo     GroupRepository
	cfg           *config.Config
	stopCh        chan struct{}    // closed by Stop to terminate background workers
	stopOnce      sync.Once        // guards closing stopCh
	wg            sync.WaitGroup   // tracks background workers for Stop to wait on
	fallbackLimit *fallbackLimiter // limits DB fallback QPS (cfg.Gateway.Scheduling.DbFallbackMaxQPS)
	lagMu         sync.Mutex       // protects lagFailures
	lagFailures   int              // NOTE(review): consumer not visible in this chunk — presumably consecutive outbox-lag failures
}
// NewSchedulerSnapshotService wires up a snapshot service over the given
// cache, outbox reader, repositories and configuration. When cfg is nil the
// DB-fallback limiter is built with a QPS of 0.
func NewSchedulerSnapshotService(
	cache SchedulerCache,
	outboxRepo SchedulerOutboxRepository,
	accountRepo AccountRepository,
	groupRepo GroupRepository,
	cfg *config.Config,
) *SchedulerSnapshotService {
	qps := 0
	if cfg != nil {
		qps = cfg.Gateway.Scheduling.DbFallbackMaxQPS
	}
	svc := &SchedulerSnapshotService{
		cache:       cache,
		outboxRepo:  outboxRepo,
		accountRepo: accountRepo,
		groupRepo:   groupRepo,
		cfg:         cfg,
		stopCh:      make(chan struct{}),
	}
	svc.fallbackLimit = newFallbackLimiter(qps)
	return svc
}
// Start launches the background workers: a one-off startup rebuild, the
// outbox poller (only when an outbox repository is set and the poll interval
// is positive), and the periodic full rebuild (only when its interval is
// positive). Each worker is tracked in wg so Stop can wait for it.
// No-op when the service or its cache is nil.
func (s *SchedulerSnapshotService) Start() {
	if s == nil || s.cache == nil {
		return
	}
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.runInitialRebuild()
	}()
	interval := s.outboxPollInterval()
	if s.outboxRepo != nil && interval > 0 {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.runOutboxWorker(interval)
		}()
	}
	fullInterval := s.fullRebuildInterval()
	if fullInterval > 0 {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.runFullRebuildWorker(fullInterval)
		}()
	}
}
// Stop signals all background workers to exit and blocks until they return.
// Safe to call multiple times (stopOnce) and on a nil receiver.
func (s *SchedulerSnapshotService) Stop() {
	if s == nil {
		return
	}
	s.stopOnce.Do(func() {
		close(s.stopCh)
	})
	s.wg.Wait()
}
// ListSchedulableAccounts returns the accounts eligible for scheduling in the
// bucket derived from (groupID, platform, mode), preferring the cached
// snapshot and falling back to the database when the cache misses or errors.
// The second result reports whether mixed-platform scheduling applies
// (anthropic or gemini without a forced platform).
func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
	useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
	mode := s.resolveMode(platform, hasForcePlatform)
	bucket := s.bucketFor(groupID, platform, mode)
	if s.cache != nil {
		cached, hit, err := s.cache.GetSnapshot(ctx, bucket)
		if err != nil {
			// Cache read failures degrade to the DB fallback below.
			log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
		} else if hit {
			return derefAccounts(cached), useMixed, nil
		}
	}
	// DB fallback is rate-limited (guardFallback) to protect the database
	// when the cache is cold or unavailable.
	if err := s.guardFallback(ctx); err != nil {
		return nil, useMixed, err
	}
	fallbackCtx, cancel := s.withFallbackTimeout(ctx)
	defer cancel()
	accounts, err := s.loadAccountsFromDB(fallbackCtx, bucket, useMixed)
	if err != nil {
		return nil, useMixed, err
	}
	// Backfill the cache so subsequent requests hit the snapshot.
	if s.cache != nil {
		if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil {
			log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
		}
	}
	return accounts, useMixed, nil
}
// GetAccount resolves a single account by ID, consulting the cache first and
// falling back to the database under the fallback guard and timeout.
// A non-positive ID yields (nil, nil).
func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
	if accountID <= 0 {
		return nil, nil
	}
	if s.cache != nil {
		cached, err := s.cache.GetAccount(ctx, accountID)
		switch {
		case err != nil:
			// Cache errors degrade to the DB fallback below.
			log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
		case cached != nil:
			return cached, nil
		}
	}
	if err := s.guardFallback(ctx); err != nil {
		return nil, err
	}
	dbCtx, cancel := s.withFallbackTimeout(ctx)
	defer cancel()
	return s.accountRepo.GetByID(dbCtx, accountID)
}
// runInitialRebuild warms the snapshot cache once at startup under a
// two-minute budget. When no buckets are registered yet, the default bucket
// set is rebuilt instead.
func (s *SchedulerSnapshotService) runInitialRebuild() {
	if s.cache == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
	buckets, err := s.cache.ListBuckets(ctx)
	if err != nil {
		// Non-fatal: fall through and try the default bucket set.
		log.Printf("[Scheduler] list buckets failed: %v", err)
	}
	if len(buckets) == 0 {
		if buckets, err = s.defaultBuckets(ctx); err != nil {
			log.Printf("[Scheduler] default buckets failed: %v", err)
			return
		}
	}
	if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil {
		log.Printf("[Scheduler] rebuild startup failed: %v", err)
	}
}
// runOutboxWorker polls the outbox on a fixed cadence until Stop is called.
// One immediate poll runs before the ticker so startup lag stays minimal.
func (s *SchedulerSnapshotService) runOutboxWorker(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	s.pollOutbox()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			s.pollOutbox()
		}
	}
}
// runFullRebuildWorker periodically rebuilds every bucket snapshot until
// Stop is called; rebuild failures are logged, not fatal.
func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			if err := s.triggerFullRebuild("interval"); err != nil {
				log.Printf("[Scheduler] full rebuild failed: %v", err)
			}
		}
	}
}
// pollOutbox drains up to 200 outbox events past the stored watermark,
// applies each one, and advances the watermark only after every event in the
// batch succeeded. A failed event aborts the batch so the whole run is
// retried on the next tick (at-least-once delivery).
func (s *SchedulerSnapshotService) pollOutbox() {
	if s.outboxRepo == nil || s.cache == nil {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	watermark, err := s.cache.GetOutboxWatermark(ctx)
	if err != nil {
		log.Printf("[Scheduler] outbox watermark read failed: %v", err)
		return
	}
	events, err := s.outboxRepo.ListAfter(ctx, watermark, 200)
	if err != nil {
		log.Printf("[Scheduler] outbox poll failed: %v", err)
		return
	}
	if len(events) == 0 {
		return
	}
	watermarkForCheck := watermark
	for _, event := range events {
		// Each event gets its own timeout so one slow handler cannot consume
		// the whole batch budget.
		eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout)
		err := s.handleOutboxEvent(eventCtx, event)
		cancel()
		if err != nil {
			// Abort without moving the watermark; the batch is re-read next tick.
			log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
			return
		}
	}
	lastID := events[len(events)-1].ID
	if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
		log.Printf("[Scheduler] outbox watermark write failed: %v", err)
	} else {
		watermarkForCheck = lastID
	}
	// Lag/backlog detection uses the oldest event of the batch just processed.
	s.checkOutboxLag(ctx, events[0], watermarkForCheck)
}
// handleOutboxEvent dispatches a single outbox event to its handler.
// Unknown event types are ignored so newer producers cannot wedge older
// consumers. The AccountGroupsChanged and AccountChanged cases share one
// handler: both refresh a single account and its group buckets (previously
// duplicated case bodies, now merged).
func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error {
	switch event.EventType {
	case SchedulerOutboxEventAccountLastUsed:
		return s.handleLastUsedEvent(ctx, event.Payload)
	case SchedulerOutboxEventAccountBulkChanged:
		return s.handleBulkAccountEvent(ctx, event.Payload)
	case SchedulerOutboxEventAccountGroupsChanged, SchedulerOutboxEventAccountChanged:
		return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
	case SchedulerOutboxEventGroupChanged:
		return s.handleGroupEvent(ctx, event.GroupID)
	case SchedulerOutboxEventFullRebuild:
		return s.triggerFullRebuild("outbox")
	default:
		return nil
	}
}
// handleLastUsedEvent applies a batch of account last-used timestamps from
// the event payload to the cache. Malformed keys and non-positive unix
// timestamps are skipped; an empty result is a no-op.
func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payload map[string]any) error {
	if s.cache == nil || payload == nil {
		return nil
	}
	raw, ok := payload["last_used"].(map[string]any)
	if !ok || len(raw) == 0 {
		return nil
	}
	updates := make(map[int64]time.Time, len(raw))
	for key, value := range raw {
		accountID, err := strconv.ParseInt(key, 10, 64)
		if err != nil || accountID <= 0 {
			continue
		}
		if sec, ok := toInt64(value); ok && sec > 0 {
			updates[accountID] = time.Unix(sec, 0)
		}
	}
	if len(updates) == 0 {
		return nil
	}
	return s.cache.UpdateLastUsed(ctx, updates)
}
// handleBulkAccountEvent fans a bulk-change event out to the per-account
// handler, stopping at the first error so the event is retried as a whole.
func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error {
	if payload == nil {
		return nil
	}
	for _, accountID := range parseInt64Slice(payload["account_ids"]) {
		accountID := accountID // local copy; handler takes a pointer
		if err := s.handleAccountEvent(ctx, &accountID, payload); err != nil {
			return err
		}
	}
	return nil
}
// handleAccountEvent refreshes the cached record and bucket snapshots for a
// single changed account. When the account no longer exists in the DB, its
// cache entry is deleted and the affected group buckets are rebuilt so the
// stale member drops out of scheduling.
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
	if accountID == nil || *accountID <= 0 {
		return nil
	}
	if s.accountRepo == nil {
		return nil
	}
	// Prefer group IDs carried on the event payload; they describe the
	// membership at the time of the change.
	var groupIDs []int64
	if payload != nil {
		groupIDs = parseInt64Slice(payload["group_ids"])
	}
	account, err := s.accountRepo.GetByID(ctx, *accountID)
	if err != nil {
		if errors.Is(err, ErrAccountNotFound) {
			if s.cache != nil {
				if err := s.cache.DeleteAccount(ctx, *accountID); err != nil {
					return err
				}
			}
			return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss")
		}
		return err
	}
	if s.cache != nil {
		if err := s.cache.SetAccount(ctx, account); err != nil {
			return err
		}
	}
	// Fall back to the account's current memberships when the payload had none.
	if len(groupIDs) == 0 {
		groupIDs = account.GroupIDs
	}
	return s.rebuildByAccount(ctx, account, groupIDs, "account_change")
}
// handleGroupEvent rebuilds every bucket belonging to the changed group;
// nil or non-positive group IDs are ignored.
func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error {
	if groupID == nil || *groupID <= 0 {
		return nil
	}
	return s.rebuildByGroupIDs(ctx, []int64{*groupID}, "group_change")
}
// rebuildByAccount refreshes the buckets an account participates in. For
// Antigravity accounts with mixed scheduling enabled, the Anthropic and
// Gemini buckets are refreshed too, since such accounts also appear there.
// All rebuilds are attempted; the first error encountered is returned.
func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error {
	if account == nil {
		return nil
	}
	groupIDs = s.normalizeGroupIDs(groupIDs)
	if len(groupIDs) == 0 {
		return nil
	}
	platforms := []string{account.Platform}
	if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
		platforms = append(platforms, PlatformAnthropic, PlatformGemini)
	}
	var firstErr error
	for _, platform := range platforms {
		if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// rebuildByGroupIDs refreshes every platform's buckets for the given groups.
// All platforms are attempted; the first error encountered is returned.
func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error {
	groupIDs = s.normalizeGroupIDs(groupIDs)
	if len(groupIDs) == 0 {
		return nil
	}
	var firstErr error
	for _, platform := range []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} {
		if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// rebuildBucketsForPlatform refreshes the single and forced buckets for each
// group on the platform, plus the mixed bucket for Anthropic/Gemini (the
// platforms that support mixed scheduling). All buckets are attempted; the
// first error encountered is returned.
func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error {
	if platform == "" {
		return nil
	}
	modes := []string{SchedulerModeSingle, SchedulerModeForced}
	if platform == PlatformAnthropic || platform == PlatformGemini {
		modes = append(modes, SchedulerModeMixed)
	}
	var firstErr error
	for _, gid := range groupIDs {
		for _, mode := range modes {
			bucket := SchedulerBucket{GroupID: gid, Platform: platform, Mode: mode}
			if err := s.rebuildBucket(ctx, bucket, reason); err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// rebuildBuckets refreshes each bucket in turn, attempting all of them and
// returning the first error encountered.
func (s *SchedulerSnapshotService) rebuildBuckets(ctx context.Context, buckets []SchedulerBucket, reason string) error {
	var firstErr error
	for i := range buckets {
		if err := s.rebuildBucket(ctx, buckets[i], reason); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// rebuildBucket reloads one bucket's snapshot from the database and writes
// it to the cache, under a 30s distributed bucket lock so concurrent
// instances do not rebuild the same bucket simultaneously.
func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket SchedulerBucket, reason string) error {
	if s.cache == nil {
		return ErrSchedulerCacheNotReady
	}
	ok, err := s.cache.TryLockBucket(ctx, bucket, 30*time.Second)
	if err != nil {
		return err
	}
	if !ok {
		// Another instance holds the lock; treat its in-flight rebuild as ours.
		return nil
	}
	// NOTE(review): the lock is never explicitly released here — presumably it
	// expires via the 30s TTL; confirm TryLockBucket's contract.
	rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed)
	if err != nil {
		log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
		return err
	}
	if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil {
		log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
		return err
	}
	log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts))
	return nil
}
// triggerFullRebuild rebuilds every known bucket — or the default set when
// none are registered — under a two-minute budget.
func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error {
	if s.cache == nil {
		return ErrSchedulerCacheNotReady
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()
	buckets, err := s.cache.ListBuckets(ctx)
	if err != nil {
		log.Printf("[Scheduler] list buckets failed: %v", err)
		return err
	}
	if len(buckets) == 0 {
		if buckets, err = s.defaultBuckets(ctx); err != nil {
			log.Printf("[Scheduler] default buckets failed: %v", err)
			return err
		}
	}
	return s.rebuildBuckets(ctx, buckets, reason)
}
// checkOutboxLag inspects processing health after a batch: it warns when lag
// passes the configured warn threshold, triggers a full rebuild after the
// configured number of consecutive excessive-lag observations, and
// separately triggers a rebuild when the row backlog behind the watermark
// exceeds its threshold.
func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest SchedulerOutboxEvent, watermark int64) {
	if oldest.CreatedAt.IsZero() || s.cfg == nil {
		return
	}
	// Lag is measured from the oldest event in the batch just processed.
	lag := time.Since(oldest.CreatedAt)
	if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 {
		log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds)
	}
	if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds {
		// Count consecutive over-threshold observations; rebuild only after
		// the configured number of them to avoid flapping.
		s.lagMu.Lock()
		s.lagFailures++
		failures := s.lagFailures
		s.lagMu.Unlock()
		if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures {
			log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
			s.lagMu.Lock()
			s.lagFailures = 0
			s.lagMu.Unlock()
			if err := s.triggerFullRebuild("outbox_lag"); err != nil {
				log.Printf("[Scheduler] outbox lag rebuild failed: %v", err)
			}
		}
	} else {
		// Lag back under control: reset the failure streak.
		s.lagMu.Lock()
		s.lagFailures = 0
		s.lagMu.Unlock()
	}
	threshold := s.cfg.Gateway.Scheduling.OutboxBacklogRebuildRows
	if threshold <= 0 || s.outboxRepo == nil {
		return
	}
	maxID, err := s.outboxRepo.MaxID(ctx)
	if err != nil {
		// Best effort: skip the backlog check when the query fails.
		return
	}
	if maxID-watermark >= int64(threshold) {
		log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
		if err := s.triggerFullRebuild("outbox_backlog"); err != nil {
			log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err)
		}
	}
}
// loadAccountsFromDB queries schedulable accounts for a bucket straight from
// the database. In mixed mode it loads the bucket's platform together with
// Antigravity and keeps only Antigravity accounts that opted into mixed
// scheduling. Simple run mode disables group scoping.
func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool) ([]Account, error) {
	if s.accountRepo == nil {
		return nil, ErrSchedulerCacheNotReady
	}
	groupID := bucket.GroupID
	if s.isRunModeSimple() {
		groupID = 0
	}
	if !useMixed {
		if groupID > 0 {
			return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
		}
		return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
	}
	platforms := []string{bucket.Platform, PlatformAntigravity}
	var (
		accounts []Account
		err      error
	)
	if groupID > 0 {
		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
	} else {
		accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
	}
	if err != nil {
		return nil, err
	}
	// Drop Antigravity accounts that have not enabled mixed scheduling.
	filtered := make([]Account, 0, len(accounts))
	for i := range accounts {
		acc := accounts[i]
		if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
			continue
		}
		filtered = append(filtered, acc)
	}
	return filtered, nil
}
// bucketFor builds the cache bucket key for a request, normalizing the
// optional group pointer into a concrete group ID.
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
	bucket := SchedulerBucket{Platform: platform, Mode: mode}
	bucket.GroupID = s.normalizeGroupID(groupID)
	return bucket
}
// normalizeGroupID maps a request's group pointer to a bucket group ID:
// 0 in simple run mode, or when the pointer is nil or non-positive.
func (s *SchedulerSnapshotService) normalizeGroupID(groupID *int64) int64 {
	switch {
	case s.isRunModeSimple():
		return 0
	case groupID == nil || *groupID <= 0:
		return 0
	default:
		return *groupID
	}
}
// normalizeGroupIDs deduplicates positive group IDs while preserving
// first-seen order. Simple run mode, empty input, and all-invalid input each
// collapse to the global group {0}.
func (s *SchedulerSnapshotService) normalizeGroupIDs(groupIDs []int64) []int64 {
	if s.isRunModeSimple() || len(groupIDs) == 0 {
		return []int64{0}
	}
	seen := make(map[int64]struct{}, len(groupIDs))
	unique := make([]int64, 0, len(groupIDs))
	for _, id := range groupIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		unique = append(unique, id)
	}
	if len(unique) == 0 {
		return []int64{0}
	}
	return unique
}
// resolveMode picks the scheduling mode for a request: forced when the
// caller pinned a platform, mixed for Anthropic/Gemini, single otherwise.
func (s *SchedulerSnapshotService) resolveMode(platform string, hasForcePlatform bool) string {
	switch {
	case hasForcePlatform:
		return SchedulerModeForced
	case platform == PlatformAnthropic, platform == PlatformGemini:
		return SchedulerModeMixed
	default:
		return SchedulerModeSingle
	}
}
// guardFallback decides whether a direct-DB fallback may proceed: fallbacks
// must be enabled (or the config absent) and within the QPS budget.
// The context parameter is currently unused but kept for call symmetry.
func (s *SchedulerSnapshotService) guardFallback(ctx context.Context) error {
	if s.cfg != nil && !s.cfg.Gateway.Scheduling.DbFallbackEnabled {
		return ErrSchedulerCacheNotReady
	}
	if s.fallbackLimit != nil && !s.fallbackLimit.Allow() {
		return ErrSchedulerFallbackLimited
	}
	return nil
}
// withFallbackTimeout derives a context for the DB fallback, capped at the
// configured timeout but never extending past the caller's own deadline.
// With no configured timeout — or a caller deadline already expired — the
// parent context is returned with only a cancel func.
func (s *SchedulerSnapshotService) withFallbackTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	seconds := 0
	if s.cfg != nil {
		seconds = s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds
	}
	if seconds <= 0 {
		return context.WithCancel(ctx)
	}
	timeout := time.Duration(seconds) * time.Second
	if deadline, ok := ctx.Deadline(); ok {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			return context.WithCancel(ctx)
		}
		if remaining < timeout {
			timeout = remaining
		}
	}
	return context.WithTimeout(ctx, timeout)
}
// isRunModeSimple reports whether the service runs in simple mode, where
// group scoping is disabled.
func (s *SchedulerSnapshotService) isRunModeSimple() bool {
	if s.cfg == nil {
		return false
	}
	return s.cfg.RunMode == config.RunModeSimple
}
// outboxPollInterval returns the configured outbox poll cadence, defaulting
// to one second when the config is absent or the value is non-positive.
func (s *SchedulerSnapshotService) outboxPollInterval() time.Duration {
	const fallback = time.Second
	if s.cfg == nil {
		return fallback
	}
	if sec := s.cfg.Gateway.Scheduling.OutboxPollIntervalSeconds; sec > 0 {
		return time.Duration(sec) * time.Second
	}
	return fallback
}
// fullRebuildInterval returns the configured full-rebuild cadence; zero
// disables the periodic rebuild worker.
func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
	if s.cfg == nil {
		return 0
	}
	if sec := s.cfg.Gateway.Scheduling.FullRebuildIntervalSeconds; sec > 0 {
		return time.Duration(sec) * time.Second
	}
	return 0
}
// defaultBuckets derives the bucket set to warm when the cache has none
// registered: global (group 0) buckets for every platform, plus per-group
// buckets for each active group outside simple run mode. Group-listing
// errors are swallowed, limiting the result to the global set.
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
	var buckets []SchedulerBucket
	// appendFor adds the single/forced buckets for a group+platform, plus the
	// mixed bucket for the mixed-capable platforms (Anthropic/Gemini).
	appendFor := func(groupID int64, platform string) {
		buckets = append(buckets,
			SchedulerBucket{GroupID: groupID, Platform: platform, Mode: SchedulerModeSingle},
			SchedulerBucket{GroupID: groupID, Platform: platform, Mode: SchedulerModeForced},
		)
		if platform == PlatformAnthropic || platform == PlatformGemini {
			buckets = append(buckets, SchedulerBucket{GroupID: groupID, Platform: platform, Mode: SchedulerModeMixed})
		}
	}
	for _, platform := range []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} {
		appendFor(0, platform)
	}
	if s.isRunModeSimple() || s.groupRepo == nil {
		return dedupeBuckets(buckets), nil
	}
	groups, err := s.groupRepo.ListActive(ctx)
	if err != nil {
		// Best effort: fall back to the global buckets only.
		return dedupeBuckets(buckets), nil
	}
	for _, group := range groups {
		if group.Platform != "" {
			appendFor(group.ID, group.Platform)
		}
	}
	return dedupeBuckets(buckets), nil
}
// dedupeBuckets removes duplicate buckets (keyed by their string form) while
// preserving first-seen order.
func dedupeBuckets(in []SchedulerBucket) []SchedulerBucket {
	out := make([]SchedulerBucket, 0, len(in))
	seen := make(map[string]struct{}, len(in))
	for _, bucket := range in {
		key := bucket.String()
		if _, dup := seen[key]; !dup {
			seen[key] = struct{}{}
			out = append(out, bucket)
		}
	}
	return out
}
// derefAccounts flattens a slice of account pointers into values, skipping
// nil entries. The result is always non-nil so callers see an empty list
// rather than a nil slice.
func derefAccounts(accounts []*Account) []Account {
	out := make([]Account, 0, len(accounts))
	for _, account := range accounts {
		if account != nil {
			out = append(out, *account)
		}
	}
	return out
}
// parseInt64Slice extracts the positive int64 values from a decoded JSON
// array ([]any); any other input shape yields nil.
func parseInt64Slice(value any) []int64 {
	items, ok := value.([]any)
	if !ok {
		return nil
	}
	out := make([]int64, 0, len(items))
	for _, item := range items {
		v, ok := toInt64(item)
		if ok && v > 0 {
			out = append(out, v)
		}
	}
	return out
}
// toInt64 coerces the numeric representations produced by JSON decoding
// (float64, json.Number) and plain Go ints into an int64. The bool result
// reports whether the conversion succeeded.
func toInt64(value any) (int64, bool) {
	switch v := value.(type) {
	case int:
		return int64(v), true
	case int64:
		return v, true
	case float64:
		// Default JSON decoding yields float64; fractional parts truncate.
		return int64(v), true
	case json.Number:
		parsed, err := strconv.ParseInt(v.String(), 10, 64)
		if err != nil {
			return 0, false
		}
		return parsed, true
	}
	return 0, false
}
// fallbackLimiter caps direct-DB fallback reads to a fixed QPS using a
// simple one-second fixed window, reset lazily on the first Allow call of
// each new window.
type fallbackLimiter struct {
	maxQPS int        // requests admitted per window; <=0 means unlimited
	mu     sync.Mutex // guards window and count
	window time.Time  // start of the current one-second window
	count  int        // requests admitted in the current window
}
// newFallbackLimiter builds a limiter admitting maxQPS fallback reads per
// second. A non-positive limit returns nil, which Allow treats as unlimited.
func newFallbackLimiter(maxQPS int) *fallbackLimiter {
	if maxQPS <= 0 {
		return nil
	}
	limiter := &fallbackLimiter{maxQPS: maxQPS}
	limiter.window = time.Now()
	return limiter
}
// Allow reports whether one more fallback read fits in the current
// one-second window. A nil or unlimited limiter always allows.
func (l *fallbackLimiter) Allow() bool {
	if l == nil || l.maxQPS <= 0 {
		return true
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	if now := time.Now(); now.Sub(l.window) >= time.Second {
		// A full second has passed: open a fresh window.
		l.window = now
		l.count = 0
	}
	if l.count >= l.maxQPS {
		return false
	}
	l.count++
	return true
}

View File

@@ -0,0 +1,63 @@
package service
import (
"context"
"time"
)
// SessionLimitCache manages account-level active-session tracking,
// used to enforce session-count limits for Anthropic OAuth/SetupToken accounts.
//
// Key format: session_limit:account:{accountID}
// Data structure: Sorted Set (member=sessionUUID, score=timestamp)
//
// Sessions expire automatically after the idle timeout; no manual cleanup is needed.
type SessionLimitCache interface {
	// RegisterSession records session activity.
	// - If the session already exists, its timestamp is refreshed and true is returned.
	// - If the session is new and the active-session count < maxSessions, it is added and true is returned.
	// - If the session is new and the active-session count >= maxSessions, false is returned (rejected).
	//
	// Parameters:
	//   accountID: account ID
	//   sessionUUID: session UUID extracted from metadata.user_id
	//   maxSessions: maximum number of concurrent sessions
	//   idleTimeout: session idle-expiry duration
	//
	// Returns:
	//   allowed: true when within the limit or the session already exists; false when rejected (over the limit and a new session)
	//   error: operation error
	RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
	// RefreshSession refreshes an existing session's timestamp,
	// keeping an active session alive.
	RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
	// GetActiveSessionCount returns the current number of active
	// (non-expired) sessions for an account.
	GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
	// GetActiveSessionCountBatch fetches active-session counts for multiple
	// accounts as map[accountID]count; accounts whose lookup failed are absent.
	GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
	// IsSessionActive reports whether a specific session is active (not expired).
	IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
	// ========== 5h-window cost cache ==========
	// Key format: window_cost:account:{accountID}
	// Caches an account's standard cost within the current 5h window to
	// reduce database aggregation load.
	// GetWindowCost returns the cached window cost.
	// Returns (cost, true, nil) on a cache hit,
	// (0, false, nil) on a miss,
	// and (0, false, err) on error.
	GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
	// SetWindowCost stores a window-cost cache entry.
	SetWindowCost(ctx context.Context, accountID int64, cost float64) error
	// GetWindowCostBatch fetches window costs for multiple accounts as
	// map[accountID]cost; cache misses are absent from the map.
	GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}

Some files were not shown because too many files have changed in this diff Show More