Merge upstream/main
This commit is contained in:
@@ -9,16 +9,19 @@ import (
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
|
||||
// BillingRateMultiplier 返回账号计费倍率。
|
||||
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
|
||||
// - 允许 0,表示该账号计费为 0
|
||||
// - 负数属于非法数据,出于安全考虑按 1.0 处理
|
||||
func (a *Account) BillingRateMultiplier() float64 {
|
||||
if a == nil || a.RateMultiplier == nil {
|
||||
return 1.0
|
||||
}
|
||||
if *a.RateMultiplier < 0 {
|
||||
return 1.0
|
||||
}
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -540,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WindowCostSchedulability 窗口费用调度状态
|
||||
type WindowCostSchedulability int
|
||||
|
||||
const (
|
||||
// WindowCostSchedulable 可正常调度
|
||||
WindowCostSchedulable WindowCostSchedulability = iota
|
||||
// WindowCostStickyOnly 仅允许粘性会话
|
||||
WindowCostStickyOnly
|
||||
// WindowCostNotSchedulable 完全不可调度
|
||||
WindowCostNotSchedulable
|
||||
)
|
||||
|
||||
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
|
||||
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
|
||||
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
|
||||
// 默认值为 10
|
||||
func (a *Account) GetWindowCostStickyReserve() float64 {
|
||||
if a.Extra == nil {
|
||||
return 10.0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
|
||||
val := parseExtraFloat64(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 10.0
|
||||
}
|
||||
|
||||
// GetMaxSessions 获取最大并发会话数
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetMaxSessions() int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["max_sessions"]; ok {
|
||||
return parseExtraInt(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
|
||||
// 默认值为 5 分钟
|
||||
func (a *Account) GetSessionIdleTimeoutMinutes() int {
|
||||
if a.Extra == nil {
|
||||
return 5
|
||||
}
|
||||
if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
|
||||
val := parseExtraInt(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 5
|
||||
}
|
||||
|
||||
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
|
||||
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
|
||||
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
|
||||
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
|
||||
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
|
||||
limit := a.GetWindowCostLimit()
|
||||
if limit <= 0 {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
if currentWindowCost < limit {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
stickyReserve := a.GetWindowCostStickyReserve()
|
||||
if currentWindowCost < limit+stickyReserve {
|
||||
return WindowCostStickyOnly
|
||||
}
|
||||
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v
|
||||
case float32:
|
||||
return float64(v)
|
||||
case int:
|
||||
return float64(v)
|
||||
case int64:
|
||||
return float64(v)
|
||||
case json.Number:
|
||||
if f, err := v.Float64(); err == nil {
|
||||
return f
|
||||
}
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseExtraInt 从 extra 字段解析 int 值
|
||||
func parseExtraInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
|
||||
var a Account
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
|
||||
require.Nil(t, a.RateMultiplier)
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
|
||||
v := 0.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 0.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
|
||||
v := -1.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
@@ -51,11 +51,13 @@ type AccountRepository interface {
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||
ClearModelRateLimits(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
@@ -64,14 +66,15 @@ type AccountRepository interface {
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
|
||||
@@ -147,6 +147,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
@@ -167,6 +171,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
|
||||
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearModelRateLimits call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
panic("unexpected UpdateSessionWindow call")
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
|
||||
}
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
//
|
||||
// cost: 账号口径费用(total_cost * account_rate_multiplier)
|
||||
// standard_cost: 标准费用(total_cost,不含倍率)
|
||||
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
|
||||
type WindowStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
StandardCost float64 `json:"standard_cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
}
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
windowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}
|
||||
|
||||
// 缓存窗口统计(1 分钟)
|
||||
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
}
|
||||
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
|
||||
// 用于账号列表页面显示当前窗口费用
|
||||
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,8 @@ type AdminService interface {
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
||||
DeleteProxy(ctx context.Context, id int64) error
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
|
||||
BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||
|
||||
@@ -106,6 +107,9 @@ type CreateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -125,6 +129,9 @@ type UpdateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -137,6 +144,7 @@ type CreateAccountInput struct {
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
@@ -152,8 +160,9 @@ type UpdateAccountInput struct {
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
@@ -163,16 +172,17 @@ type UpdateAccountInput struct {
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
@@ -187,9 +197,11 @@ type BulkUpdateAccountResult struct {
|
||||
|
||||
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
|
||||
type BulkUpdateAccountsResult struct {
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
Results []BulkUpdateAccountResult `json:"results"`
|
||||
Success int `json:"success"`
|
||||
Failed int `json:"failed"`
|
||||
SuccessIDs []int64 `json:"success_ids"`
|
||||
FailedIDs []int64 `json:"failed_ids"`
|
||||
Results []BulkUpdateAccountResult `json:"results"`
|
||||
}
|
||||
|
||||
type CreateProxyInput struct {
|
||||
@@ -219,23 +231,35 @@ type GenerateRedeemCodesInput struct {
|
||||
ValidityDays int // 订阅类型专用:有效天数
|
||||
}
|
||||
|
||||
// ProxyTestResult represents the result of testing a proxy
|
||||
type ProxyTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
type ProxyBatchDeleteResult struct {
|
||||
DeletedIDs []int64 `json:"deleted_ids"`
|
||||
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||
type ProxyBatchDeleteSkipped struct {
|
||||
ID int64 `json:"id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// ProxyTestResult represents the result of testing a proxy
|
||||
type ProxyTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ip-api.com
|
||||
type ProxyExitInfo struct {
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
CountryCode string
|
||||
}
|
||||
|
||||
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||
@@ -245,14 +269,16 @@ type ProxyExitInfoProber interface {
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
@@ -265,16 +291,20 @@ func NewAdminService(
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,6 +354,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
}
|
||||
|
||||
oldConcurrency := user.Concurrency
|
||||
oldStatus := user.Status
|
||||
oldRole := user.Role
|
||||
|
||||
if input.Email != "" {
|
||||
user.Email = input.Email
|
||||
@@ -356,6 +388,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
@@ -394,6 +431,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
log.Printf("delete user failed: user_id=%d err=%v", id, err)
|
||||
return err
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -421,6 +461,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if s.authCacheInvalidator != nil && balanceDiff != 0 {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
@@ -432,7 +476,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
}()
|
||||
}
|
||||
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
@@ -545,6 +588,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -577,18 +621,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
visited := map[int64]struct{}{}
|
||||
nextID := fallbackGroupID
|
||||
for {
|
||||
if _, seen := visited[nextID]; seen {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[nextID] = struct{}{}
|
||||
if currentGroupID > 0 && nextID == currentGroupID {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
if fallbackGroup.FallbackGroupID == nil {
|
||||
return nil
|
||||
}
|
||||
nextID = *fallbackGroup.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
@@ -658,13 +717,32 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
}
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
group.ModelRouting = input.ModelRouting
|
||||
}
|
||||
if input.ModelRoutingEnabled != nil {
|
||||
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
var groupKeys []string
|
||||
if s.authCacheInvalidator != nil {
|
||||
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
|
||||
if err == nil {
|
||||
groupKeys = keys
|
||||
}
|
||||
}
|
||||
|
||||
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -683,6 +761,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
}
|
||||
}()
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
for _, key := range groupKeys {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -769,6 +852,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -821,6 +910,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Priority != nil {
|
||||
account.Priority = *input.Priority
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
@@ -871,7 +966,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
// It merges credentials/extra keys instead of overwriting the whole object.
|
||||
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
|
||||
result := &BulkUpdateAccountsResult{
|
||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
|
||||
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
|
||||
}
|
||||
|
||||
if len(input.AccountIDs) == 0 {
|
||||
@@ -892,6 +989,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
}
|
||||
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
@@ -909,6 +1012,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.Priority != nil {
|
||||
repoUpdates.Priority = input.Priority
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
@@ -935,6 +1041,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
@@ -944,6 +1051,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
@@ -953,6 +1061,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.FailedIDs = append(result.FailedIDs, accountID)
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
@@ -960,6 +1069,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
|
||||
entry.Success = true
|
||||
result.Success++
|
||||
result.SuccessIDs = append(result.SuccessIDs, accountID)
|
||||
result.Results = append(result.Results, entry)
|
||||
}
|
||||
|
||||
@@ -1019,6 +1129,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
s.attachProxyLatency(ctx, proxies)
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
@@ -1027,7 +1138,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
return s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.attachProxyLatency(ctx, proxies)
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
|
||||
@@ -1047,6 +1163,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Probe latency asynchronously so creation isn't blocked by network timeout.
|
||||
go s.probeProxyLatency(context.Background(), proxy)
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
@@ -1085,12 +1203,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
||||
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return ErrProxyInUse
|
||||
}
|
||||
return s.proxyRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
|
||||
// Return mock data for now - would need a dedicated repository method
|
||||
return []Account{}, 0, nil
|
||||
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
|
||||
result := &ProxyBatchDeleteResult{}
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||
if err != nil {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if count > 0 {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: ErrProxyInUse.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if err := s.proxyRepo.Delete(ctx, id); err != nil {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
result.DeletedIDs = append(result.DeletedIDs, id)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
@@ -1190,23 +1349,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
proxyURL := proxy.URL()
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||
if err != nil {
|
||||
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
latency := latencyMs
|
||||
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||
Success: true,
|
||||
LatencyMs: &latency,
|
||||
Message: "Proxy is accessible",
|
||||
IPAddress: exitInfo.IP,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
Region: exitInfo.Region,
|
||||
City: exitInfo.City,
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
||||
if s.proxyProber == nil || proxy == nil {
|
||||
return
|
||||
}
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
|
||||
if err != nil {
|
||||
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
latency := latencyMs
|
||||
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||
Success: true,
|
||||
LatencyMs: &latency,
|
||||
Message: "Proxy is accessible",
|
||||
IPAddress: exitInfo.IP,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
Region: exitInfo.Region,
|
||||
City: exitInfo.City,
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
||||
// 如果存在混合,返回错误提示用户确认
|
||||
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
@@ -1256,6 +1461,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
ids = append(ids, proxies[i].ID)
|
||||
}
|
||||
|
||||
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
|
||||
if err != nil {
|
||||
log.Printf("Warning: load proxy latency cache failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range proxies {
|
||||
info := latencies[proxies[i].ID]
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
if info.Success {
|
||||
proxies[i].LatencyStatus = "success"
|
||||
proxies[i].LatencyMs = info.LatencyMs
|
||||
} else {
|
||||
proxies[i].LatencyStatus = "failed"
|
||||
}
|
||||
proxies[i].LatencyMessage = info.Message
|
||||
proxies[i].IPAddress = info.IPAddress
|
||||
proxies[i].Country = info.Country
|
||||
proxies[i].CountryCode = info.CountryCode
|
||||
proxies[i].Region = info.Region
|
||||
proxies[i].City = info.City
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
|
||||
if s.proxyLatencyCache == nil || info == nil {
|
||||
return
|
||||
}
|
||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
||||
log.Printf("Warning: store proxy latency cache failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
||||
func getAccountPlatform(accountPlatform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
||||
|
||||
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForBulkUpdate struct {
|
||||
accountRepoStub
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
if s.bulkUpdateErr != nil {
|
||||
return 0, s.bulkUpdateErr
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
|
||||
require.Empty(t, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
|
||||
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
bindGroupErrByID: map[int64]error{
|
||||
2: errors.New("bind failed"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
schedulable := false
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
GroupIDs: &groupIDs,
|
||||
Schedulable: &schedulable,
|
||||
SkipMixedChannelCheck: true,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 1, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
@@ -149,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
deletedIDs []int64
|
||||
deleteErr error
|
||||
countErr error
|
||||
accountCount int64
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
||||
@@ -195,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
panic("unexpected CountAccountsByProxyID call")
|
||||
if s.countErr != nil {
|
||||
return 0, s.countErr
|
||||
}
|
||||
return s.accountCount, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
panic("unexpected ListAccountSummariesByProxyID call")
|
||||
}
|
||||
|
||||
type redeemRepoStub struct {
|
||||
@@ -405,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||
repo := &proxyRepoStub{accountCount: 2}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 77)
|
||||
require.ErrorIs(t, err, ErrProxyInUse)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||
|
||||
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type balanceUserRepoStub struct {
|
||||
*userRepoStub
|
||||
updateErr error
|
||||
updated []*User
|
||||
}
|
||||
|
||||
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *user
|
||||
s.updated = append(s.updated, &clone)
|
||||
if s.userRepoStub != nil {
|
||||
s.userRepoStub.user = &clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceRedeemRepoStub struct {
|
||||
*redeemRepoStub
|
||||
created []*RedeemCode
|
||||
}
|
||||
|
||||
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *code
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
type authCacheInvalidatorStub struct {
|
||||
userIDs []int64
|
||||
groupIDs []int64
|
||||
keys []string
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
s.keys = append(s.keys, key)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
s.userIDs = append(s.userIDs, userID)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
s.groupIDs = append(s.groupIDs, groupID)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
require.Len(t, redeemRepo.created, 1)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
require.Empty(t, redeemRepo.created)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
mathrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -28,6 +29,8 @@ const (
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
|
||||
|
||||
// antigravityRetryLoopParams 重试循环的参数
|
||||
type antigravityRetryLoopParams struct {
|
||||
ctx context.Context
|
||||
@@ -38,7 +41,9 @@ type antigravityRetryLoopParams struct {
|
||||
action string
|
||||
body []byte
|
||||
quotaScope AntigravityQuotaScope
|
||||
c *gin.Context
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
|
||||
}
|
||||
|
||||
@@ -56,6 +61,17 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
|
||||
|
||||
var resp *http.Response
|
||||
var usedBaseURL string
|
||||
logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
getUpstreamDetail := func(body []byte) string {
|
||||
if !logBody {
|
||||
return ""
|
||||
}
|
||||
return truncateString(string(body), maxBytes)
|
||||
}
|
||||
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
@@ -73,8 +89,22 @@ urlFallbackLoop:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if p.c != nil && len(p.body) > 0 {
|
||||
p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
|
||||
}
|
||||
|
||||
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
Platform: p.account.Platform,
|
||||
AccountID: p.account.ID,
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
@@ -89,6 +119,7 @@ urlFallbackLoop:
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
|
||||
setOpsUpstreamError(p.c, 0, safeErr, "")
|
||||
return nil, fmt.Errorf("upstream request failed after retries: %w", err)
|
||||
}
|
||||
|
||||
@@ -99,13 +130,37 @@ urlFallbackLoop:
|
||||
|
||||
// "Resource has been exhausted" 是 URL 级别限流,切换 URL
|
||||
if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
Platform: p.account.Platform,
|
||||
AccountID: p.account.ID,
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", p.prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200))
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
// 账户/模型配额限流,重试 3 次(指数退避)
|
||||
if attempt < antigravityMaxRetries {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
Platform: p.account.Platform,
|
||||
AccountID: p.account.ID,
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
@@ -131,6 +186,18 @@ urlFallbackLoop:
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
Platform: p.account.Platform,
|
||||
AccountID: p.account.ID,
|
||||
AccountName: p.account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
@@ -679,6 +746,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
|
||||
sanitizeThinkingBlocks(&claudeReq)
|
||||
|
||||
// 获取转换选项
|
||||
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
|
||||
transformOpts := s.getClaudeTransformOptions(ctx)
|
||||
@@ -690,6 +760,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// Safety net: ensure no cache_control leaked into Gemini request
|
||||
geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
|
||||
|
||||
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
|
||||
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
|
||||
action := "streamGenerateContent"
|
||||
@@ -704,7 +777,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
action: action,
|
||||
body: geminiBody,
|
||||
quotaScope: quotaScope,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
handleError: s.handleUpstreamError,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -720,6 +795,28 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "signature_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
// Conservative two-stage fallback:
|
||||
// 1) Disable top-level thinking + thinking->text
|
||||
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
|
||||
@@ -753,6 +850,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr != nil {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "signature_retry_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
})
|
||||
log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr)
|
||||
continue
|
||||
}
|
||||
@@ -766,6 +871,26 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
kind := "signature_retry"
|
||||
if strings.TrimSpace(stage.name) != "" {
|
||||
kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_")
|
||||
}
|
||||
retryUpstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(retryBody))
|
||||
retryUpstreamMsg = sanitizeUpstreamErrorMessage(retryUpstreamMsg)
|
||||
retryUpstreamDetail := ""
|
||||
if logBody {
|
||||
retryUpstreamDetail = truncateString(string(retryBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
Kind: kind,
|
||||
Message: retryUpstreamMsg,
|
||||
Detail: retryUpstreamDetail,
|
||||
})
|
||||
|
||||
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
|
||||
if retryResp.StatusCode != http.StatusBadRequest || !isSignatureRelatedError(retryBody) {
|
||||
@@ -793,10 +918,31 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -879,6 +1025,143 @@ func extractAntigravityErrorMessage(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
|
||||
// This should not be needed if transformation is correct, but serves as a safety net
|
||||
func cleanCacheControlFromGeminiJSON(body []byte) []byte {
|
||||
// Try a more robust approach: parse and clean
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", err)
|
||||
return body
|
||||
}
|
||||
|
||||
cleaned := removeCacheControlFromAny(data)
|
||||
if !cleaned {
|
||||
return body
|
||||
}
|
||||
|
||||
if result, err := json.Marshal(data); err == nil {
|
||||
log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON")
|
||||
return result
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// removeCacheControlFromAny recursively removes cache_control fields
|
||||
func removeCacheControlFromAny(v any) bool {
|
||||
cleaned := false
|
||||
|
||||
switch val := v.(type) {
|
||||
case map[string]any:
|
||||
for k, child := range val {
|
||||
if k == "cache_control" {
|
||||
delete(val, k)
|
||||
cleaned = true
|
||||
} else if removeCacheControlFromAny(child) {
|
||||
cleaned = true
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range val {
|
||||
if removeCacheControlFromAny(item) {
|
||||
cleaned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cleaned
|
||||
}
|
||||
|
||||
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
|
||||
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
|
||||
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
|
||||
func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages))
|
||||
|
||||
// Clean system blocks
|
||||
if len(req.System) > 0 {
|
||||
var systemBlocks []map[string]any
|
||||
if err := json.Unmarshal(req.System, &systemBlocks); err == nil {
|
||||
for i := range systemBlocks {
|
||||
if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil {
|
||||
if removeCacheControlFromAny(systemBlocks[i]) {
|
||||
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Marshal back
|
||||
if cleaned, err := json.Marshal(systemBlocks); err == nil {
|
||||
req.System = cleaned
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean message content blocks and flatten history
|
||||
lastMsgIdx := len(req.Messages) - 1
|
||||
for msgIdx := range req.Messages {
|
||||
raw := req.Messages[msgIdx].Content
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to parse as blocks array
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
cleaned := false
|
||||
for blockIdx := range blocks {
|
||||
blockType, _ := blocks[blockIdx]["type"].(string)
|
||||
|
||||
// Check for thinking blocks (typed or untyped)
|
||||
if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil {
|
||||
// 1. Clean cache_control
|
||||
if removeCacheControlFromAny(blocks[blockIdx]) {
|
||||
log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx)
|
||||
cleaned = true
|
||||
}
|
||||
|
||||
// 2. Flatten to text if it's a history message (not the last one)
|
||||
if msgIdx < lastMsgIdx {
|
||||
log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx)
|
||||
|
||||
// Extract thinking content
|
||||
var textContent string
|
||||
if t, ok := blocks[blockIdx]["thinking"].(string); ok {
|
||||
textContent = t
|
||||
} else {
|
||||
// Fallback for non-string content (marshal it)
|
||||
if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil {
|
||||
textContent = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to text block
|
||||
blocks[blockIdx]["type"] = "text"
|
||||
blocks[blockIdx]["text"] = textContent
|
||||
delete(blocks[blockIdx], "thinking")
|
||||
delete(blocks[blockIdx], "signature")
|
||||
delete(blocks[blockIdx], "cache_control") // Ensure it's gone
|
||||
cleaned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal back if modified
|
||||
if cleaned {
|
||||
if marshaled, err := json.Marshal(blocks); err == nil {
|
||||
req.Messages[msgIdx].Content = marshaled
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||
// This preserves the thinking content while avoiding signature validation errors.
|
||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||
@@ -1184,7 +1467,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
action: upstreamAction,
|
||||
body: wrappedBody,
|
||||
quotaScope: quotaScope,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
handleError: s.handleUpstreamError,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1234,22 +1519,62 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 解包并返回错误
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
|
||||
unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody)
|
||||
unwrappedForOps := unwrapped
|
||||
if unwrapErr != nil || len(unwrappedForOps) == 0 {
|
||||
unwrappedForOps = respBody
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(unwrappedForOps), maxBytes)
|
||||
}
|
||||
|
||||
// Always record upstream context for Ops error logs, even when we will failover.
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(respBody, 500))
|
||||
c.Data(resp.StatusCode, contentType, unwrapped)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
|
||||
c.Data(resp.StatusCode, contentType, unwrappedForOps)
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -1338,9 +1663,15 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func antigravityUseScopeRateLimit() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
if resetAt == nil {
|
||||
// 解析失败:使用配置的 fallback 时间,直接限流整个账户
|
||||
@@ -1350,19 +1681,30 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
||||
}
|
||||
defaultDur := time.Duration(fallbackMinutes) * time.Minute
|
||||
ra := time.Now().Add(defaultDur)
|
||||
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||||
if useScopeLimit {
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
if quotaScope == "" {
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||
if useScopeLimit {
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1533,6 +1875,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
@@ -1824,9 +2167,36 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||||
// 记录上游错误详情便于调试
|
||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, string(body))
|
||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: upstreamStatus,
|
||||
UpstreamRequestID: upstreamRequestID,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
|
||||
if logBody {
|
||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes))
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
@@ -1862,7 +2232,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstr
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": errMsg},
|
||||
})
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
@@ -2189,6 +2562,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout (antigravity)")
|
||||
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
if !a.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
if a.isModelRateLimited(requestedModel) {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := antigravityTokenCacheKey(account)
|
||||
cacheKey := AntigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func antigravityTokenCacheKey(account *Account) string {
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
|
||||
@@ -3,16 +3,18 @@ package service
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
|
||||
51
backend/internal/service/api_key_auth_cache.go
Normal file
51
backend/internal/service/api_key_auth_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||
type APIKeyAuthSnapshot struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
User APIKeyAuthUserSnapshot `json:"user"`
|
||||
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
type APIKeyAuthUserSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
}
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
type APIKeyAuthCacheEntry struct {
|
||||
NotFound bool `json:"not_found"`
|
||||
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||
}
|
||||
273
backend/internal/service/api_key_auth_cache_impl.go
Normal file
273
backend/internal/service/api_key_auth_cache_impl.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
l1TTL time.Duration
|
||||
l2TTL time.Duration
|
||||
negativeTTL time.Duration
|
||||
jitterPercent int
|
||||
singleflight bool
|
||||
}
|
||||
|
||||
var (
|
||||
jitterRandMu sync.Mutex
|
||||
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
)
|
||||
|
||||
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||
if cfg == nil {
|
||||
return apiKeyAuthCacheConfig{}
|
||||
}
|
||||
auth := cfg.APIKeyAuth
|
||||
return apiKeyAuthCacheConfig{
|
||||
l1Size: auth.L1Size,
|
||||
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||
jitterPercent: auth.JitterPercent,
|
||||
singleflight: auth.Singleflight,
|
||||
}
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||
return c.l1Size > 0 && c.l1TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||
return c.l2TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||
return c.negativeTTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return ttl
|
||||
}
|
||||
if c.jitterPercent <= 0 {
|
||||
return ttl
|
||||
}
|
||||
percent := c.jitterPercent
|
||||
if percent > 100 {
|
||||
percent = 100
|
||||
}
|
||||
delta := float64(percent) / 100
|
||||
jitterRandMu.Lock()
|
||||
randVal := jitterRand.Float64()
|
||||
jitterRandMu.Unlock()
|
||||
factor := 1 - delta + randVal*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
}
|
||||
return time.Duration(float64(ttl) * factor)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||
if !s.authCfg.l1Enabled() {
|
||||
return
|
||||
}
|
||||
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||
MaxCost: int64(s.authCfg.l1Size),
|
||||
BufferItems: 64,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||
if s.authCacheL1 != nil {
|
||||
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||
return entry, true
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return nil, false
|
||||
}
|
||||
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||
if s.authCacheL1 == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
ttl := s.authCfg.l1TTL
|
||||
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||
ttl = s.authCfg.negativeTTL
|
||||
}
|
||||
ttl = s.authCfg.jitterTTL(ttl)
|
||||
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return
|
||||
}
|
||||
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
if s.authCacheL1 != nil {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||
if s.authCfg.negativeEnabled() {
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||
if snapshot == nil {
|
||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||
}
|
||||
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||
if entry == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if entry.NotFound {
|
||||
return nil, true, ErrAPIKeyNotFound
|
||||
}
|
||||
if entry.Snapshot == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
if apiKey == nil || apiKey.User == nil {
|
||||
return nil
|
||||
}
|
||||
snapshot := &APIKeyAuthSnapshot{
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: apiKey.UserID,
|
||||
GroupID: apiKey.GroupID,
|
||||
Status: apiKey.Status,
|
||||
IPWhitelist: apiKey.IPWhitelist,
|
||||
IPBlacklist: apiKey.IPBlacklist,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
Role: apiKey.User.Role,
|
||||
Balance: apiKey.User.Balance,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey := &APIKey{
|
||||
ID: snapshot.APIKeyID,
|
||||
UserID: snapshot.UserID,
|
||||
GroupID: snapshot.GroupID,
|
||||
Key: key,
|
||||
Status: snapshot.Status,
|
||||
IPWhitelist: snapshot.IPWhitelist,
|
||||
IPBlacklist: snapshot.IPBlacklist,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
Role: snapshot.User.Role,
|
||||
Balance: snapshot.User.Balance,
|
||||
Concurrency: snapshot.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
}
|
||||
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
cacheKey := s.authCacheKey(key)
|
||||
s.deleteAuthCache(ctx, cacheKey)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
if userID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
if groupID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,6 +23,7 @@ var (
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,9 +33,11 @@ const (
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||||
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
@@ -43,6 +49,8 @@ type APIKeyRepository interface {
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
@@ -53,20 +61,35 @@ type APIKeyCache interface {
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
type APIKeyAuthCacheInvalidator interface {
|
||||
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
@@ -77,6 +100,9 @@ type APIKeyService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -88,7 +114,7 @@ func NewAPIKeyService(
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
return &APIKeyService{
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -96,6 +122,8 @@ func NewAPIKeyService(
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
@@ -186,6 +214,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
@@ -236,17 +278,21 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("create api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -282,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
cacheKey := s.authCacheKey(key)
|
||||
|
||||
// 这里可以添加Redis缓存逻辑,暂时直接查询数据库
|
||||
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
|
||||
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCfg.singleflight {
|
||||
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
} else {
|
||||
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 缓存到Redis(可选,TTL设置为5分钟)
|
||||
if s.cache != nil {
|
||||
// 这里可以序列化并缓存API Key
|
||||
_ = cacheKey // 使用变量避免未使用错误
|
||||
}
|
||||
|
||||
apiKey.Key = key
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
@@ -312,6 +386,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
@@ -344,19 +432,22 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Delete 删除API Key
|
||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||||
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
@@ -366,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
return ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 清除Redis缓存(使用 ownerID 而非 apiKey.UserID)
|
||||
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
s.InvalidateAuthCacheByKey(ctx, key)
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
|
||||
423
backend/internal/service/api_key_service_cache_test.go
Normal file
423
backend/internal/service/api_key_service_cache_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type authRepoStub struct {
|
||||
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
panic("unexpected GetKeyAndOwnerID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
if s.getByKeyForAuth == nil {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
return s.getByKeyForAuth(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
if s.listKeysByUserID == nil {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
return s.listKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
if s.listKeysByGroupID == nil {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
return s.listKeysByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
setAuthKeys []string
|
||||
deleteAuthKeys []string
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
if s.getAuthCache == nil {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return s.getAuthCache(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
Snapshot: &APIKeyAuthSnapshot{
|
||||
APIKeyID: 1,
|
||||
UserID: 2,
|
||||
GroupID: &groupID,
|
||||
Status: StatusActive,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
},
|
||||
Group: &APIKeyAuthGroupSnapshot{
|
||||
ID: groupID,
|
||||
Name: "g",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-opus-*": {1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return cacheEntry, nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), apiKey.ID)
|
||||
require.Equal(t, int64(2), apiKey.User.ID)
|
||||
require.Equal(t, groupID, apiKey.Group.ID)
|
||||
require.True(t, apiKey.Group.ModelRoutingEnabled)
|
||||
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return &APIKey{
|
||||
ID: 5,
|
||||
UserID: 7,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 12,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(5), apiKey.ID)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &APIKey{
|
||||
ID: 21,
|
||||
UserID: 3,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 3,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 5,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L1Size: 1000,
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
svc.authCacheL1.Wait()
|
||||
cacheKey := svc.authCacheKey("k-l1")
|
||||
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||
require.True(t, ok)
|
||||
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, ErrAPIKeyNotFound
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return &APIKey{
|
||||
ID: 11,
|
||||
UserID: 2,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 1,
|
||||
Concurrency: 1,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
errs := make([]error, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, err := svc.GetByKey(context.Background(), "k1")
|
||||
errs[idx] = err
|
||||
}(i)
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
for _, err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
@@ -20,13 +20,12 @@ import (
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
||||
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
ownerID int64 // GetOwnerID 的返回值
|
||||
ownerErr error // GetOwnerID 的错误返回值
|
||||
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||
}
|
||||
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return nil, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
clone := *s.apiKey
|
||||
return &clone, nil
|
||||
}
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
// GetOwnerID 返回预设的所有者 ID 或错误。
|
||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
|
||||
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return s.ownerID, s.ownerErr
|
||||
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return "", 0, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||
}
|
||||
return "", 0, ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||
type apiKeyCacheStub struct {
|
||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||
}
|
||||
|
||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 1
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
// - 调用者 userID 为 2(不匹配)
|
||||
// - 返回 ErrInsufficientPerms 错误
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 1}
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 7
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||
// - 调用者 userID 为 7(匹配)
|
||||
// - Delete 成功执行
|
||||
// - 缓存被正确清除(使用 ownerID)
|
||||
// - 返回 nil 错误
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 7}
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
||||
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回正确的所有者 ID
|
||||
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||
// - 所有权验证通过
|
||||
// - 缓存被清除(在删除之前)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete api key" 的错误信息
|
||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
ownerID: 3,
|
||||
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
require.ErrorContains(t, err, "delete api key")
|
||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &UsageService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, false)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, true)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
}
|
||||
|
||||
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &RedeemService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
|
||||
groupID := int64(3)
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
|
||||
|
||||
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
|
||||
}
|
||||
@@ -52,6 +52,7 @@ type AuthService struct {
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -62,6 +63,7 @@ func NewAuthService(
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -70,16 +72,17 @@ func NewAuthService(
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -150,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
@@ -341,7 +357,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||
// - 如果邮箱不存在:创建新用户并登录
|
||||
//
|
||||
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。
|
||||
// 注意:该函数用于 LinuxDo OAuth 登录场景(不同于上游账号的 OAuth,例如 Claude/OpenAI/Gemini)。
|
||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
@@ -360,8 +376,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册。
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
// OAuth 首次登录视为注册(fail-close:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
)
|
||||
}
|
||||
|
||||
@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
208
backend/internal/service/claude_token_provider.go
Normal file
208
backend/internal/service/claude_token_provider.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeTokenRefreshSkew = 3 * time.Minute
|
||||
claudeTokenCacheSkew = 5 * time.Minute
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache ClaudeTokenCache,
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
939
backend/internal/service/claude_token_provider_test.go
Normal file
939
backend/internal/service/claude_token_provider_test.go
Normal file
@@ -0,0 +1,939 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// claudeTokenCacheStub implements ClaudeTokenCache for testing
|
||||
type claudeTokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
|
||||
return &claudeTokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
|
||||
type claudeAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// claudeOAuthServiceStub implements OAuthService methods for testing
|
||||
type claudeOAuthServiceStub struct {
|
||||
tokenInfo *TokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
// testClaudeTokenProvider is a test version that uses the stub OAuth service
|
||||
type testClaudeTokenProvider struct {
|
||||
accountRepo *claudeAccountRepoStub
|
||||
tokenCache *claudeTokenCacheStub
|
||||
oauthService *claudeOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// Build new credentials
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > claudeTokenCacheSkew {
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
getErr: errors.New("db connection failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 113,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still work, just using the passed-in account
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
updateErr: errors.New("db write failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still return token even if update fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-access-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
"custom_field": "should-be-preserved",
|
||||
"organization": "test-org",
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new-access-token", token)
|
||||
|
||||
// Verify existing fields are preserved
|
||||
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
|
||||
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
|
||||
// Verify new fields are updated
|
||||
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
|
||||
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 116,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// After lock is acquired, cache should have the token (simulating another worker)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other-worker"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set)
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
// After lock wait, return token from credentials
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cacheKey := "claude:account:303"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 303,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 304,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 305,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 306,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDashboardAggregationTimeout = 2 * time.Minute
|
||||
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
|
||||
dashboardAggregationRetentionInterval = 6 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
)
|
||||
|
||||
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||
type DashboardAggregationRepository interface {
|
||||
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
// DashboardAggregationService 负责定时聚合与回填。
|
||||
type DashboardAggregationService struct {
|
||||
repo DashboardAggregationRepository
|
||||
timingWheel *TimingWheelService
|
||||
cfg config.DashboardAggregationConfig
|
||||
running int32
|
||||
lastRetentionCleanup atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// NewDashboardAggregationService 创建聚合服务。
|
||||
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
var aggCfg config.DashboardAggregationConfig
|
||||
if cfg != nil {
|
||||
aggCfg = cfg.DashboardAgg
|
||||
}
|
||||
return &DashboardAggregationService{
|
||||
repo: repo,
|
||||
timingWheel: timingWheel,
|
||||
cfg: aggCfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时聚合作业(重启生效配置)。
|
||||
func (s *DashboardAggregationService) Start() {
|
||||
if s == nil || s.repo == nil || s.timingWheel == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
log.Printf("[DashboardAggregation] 聚合作业已禁用")
|
||||
return
|
||||
}
|
||||
|
||||
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = time.Minute
|
||||
}
|
||||
|
||||
if s.cfg.RecomputeDays > 0 {
|
||||
go s.recomputeRecentDays()
|
||||
}
|
||||
|
||||
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||
s.runScheduledAggregation()
|
||||
})
|
||||
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerBackfill 触发回填(异步)。
|
||||
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||
return ErrDashboardBackfillDisabled
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
if s.cfg.BackfillMaxDays > 0 {
|
||||
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
|
||||
if end.Sub(start) > maxRange {
|
||||
return ErrDashboardBackfillTooLarge
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 回填失败: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
start := now.AddDate(0, 0, -days)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||
last = time.Unix(0, 0).UTC()
|
||||
}
|
||||
|
||||
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
start := last.Add(-lookback)
|
||||
if !last.After(epoch) {
|
||||
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||
if retentionDays <= 0 {
|
||||
retentionDays = 1
|
||||
}
|
||||
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||
} else if start.After(now) {
|
||||
start = now.Add(-lookback)
|
||||
}
|
||||
|
||||
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
start.Format(time.RFC3339),
|
||||
now.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, now)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errors.New("聚合作业正在运行")
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
|
||||
cursor := truncateToDayUTC(startUTC)
|
||||
for cursor.Before(endUTC) {
|
||||
windowEnd := cursor.Add(24 * time.Hour)
|
||||
if windowEnd.After(endUTC) {
|
||||
windowEnd = endUTC
|
||||
}
|
||||
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
cursor = windowEnd
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
startUTC.Format(time.RFC3339),
|
||||
endUTC.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, endUTC)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if !end.After(start) {
|
||||
return nil
|
||||
}
|
||||
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
|
||||
}
|
||||
return s.repo.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
|
||||
lastAny := s.lastRetentionCleanup.Load()
|
||||
if lastAny != nil {
|
||||
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||
}
|
||||
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||
if usageErr != nil {
|
||||
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
func truncateToDayUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepoTestStub struct {
|
||||
aggregateCalls int
|
||||
lastStart time.Time
|
||||
lastEnd time.Time
|
||||
watermark time.Time
|
||||
aggregateErr error
|
||||
cleanupAggregatesErr error
|
||||
cleanupUsageErr error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
s.aggregateCalls++
|
||||
s.lastStart = start
|
||||
s.lastEnd = end
|
||||
return s.aggregateErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return s.cleanupAggregatesErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
require.False(t, repo.lastEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
BackfillEnabled: true,
|
||||
BackfillMaxDays: 1,
|
||||
},
|
||||
}
|
||||
|
||||
start := time.Now().AddDate(0, 0, -3)
|
||||
end := time.Now()
|
||||
err := svc.TriggerBackfill(start, end)
|
||||
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
|
||||
require.Equal(t, 0, repo.aggregateCalls)
|
||||
}
|
||||
@@ -2,47 +2,307 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
// DashboardService provides aggregated statistics for admin dashboard.
|
||||
type DashboardService struct {
|
||||
usageRepo UsageLogRepository
|
||||
const (
|
||||
defaultDashboardStatsFreshTTL = 15 * time.Second
|
||||
defaultDashboardStatsCacheTTL = 30 * time.Second
|
||||
defaultDashboardStatsRefreshTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
|
||||
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
|
||||
|
||||
// DashboardStatsCache 定义仪表盘统计缓存接口。
|
||||
type DashboardStatsCache interface {
|
||||
GetDashboardStats(ctx context.Context) (string, error)
|
||||
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
|
||||
DeleteDashboardStats(ctx context.Context) error
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
|
||||
type dashboardStatsRangeFetcher interface {
|
||||
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
|
||||
}
|
||||
|
||||
type dashboardStatsCacheEntry struct {
|
||||
Stats *usagestats.DashboardStats `json:"stats"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DashboardService 提供管理员仪表盘统计服务。
|
||||
type DashboardService struct {
|
||||
usageRepo UsageLogRepository
|
||||
aggRepo DashboardAggregationRepository
|
||||
cache DashboardStatsCache
|
||||
cacheFreshTTL time.Duration
|
||||
cacheTTL time.Duration
|
||||
refreshTimeout time.Duration
|
||||
refreshing int32
|
||||
aggEnabled bool
|
||||
aggInterval time.Duration
|
||||
aggLookback time.Duration
|
||||
aggUsageDays int
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
|
||||
freshTTL := defaultDashboardStatsFreshTTL
|
||||
cacheTTL := defaultDashboardStatsCacheTTL
|
||||
refreshTimeout := defaultDashboardStatsRefreshTimeout
|
||||
aggEnabled := true
|
||||
aggInterval := time.Minute
|
||||
aggLookback := 2 * time.Minute
|
||||
aggUsageDays := 90
|
||||
if cfg != nil {
|
||||
if !cfg.Dashboard.Enabled {
|
||||
cache = nil
|
||||
}
|
||||
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
|
||||
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsTTLSeconds > 0 {
|
||||
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
|
||||
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
|
||||
}
|
||||
aggEnabled = cfg.DashboardAgg.Enabled
|
||||
if cfg.DashboardAgg.IntervalSeconds > 0 {
|
||||
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.LookbackSeconds > 0 {
|
||||
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
|
||||
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
|
||||
}
|
||||
}
|
||||
if aggRepo == nil {
|
||||
aggEnabled = false
|
||||
}
|
||||
return &DashboardService{
|
||||
usageRepo: usageRepo,
|
||||
usageRepo: usageRepo,
|
||||
aggRepo: aggRepo,
|
||||
cache: cache,
|
||||
cacheFreshTTL: freshTTL,
|
||||
cacheTTL: cacheTTL,
|
||||
refreshTimeout: refreshTimeout,
|
||||
aggEnabled: aggEnabled,
|
||||
aggInterval: aggInterval,
|
||||
aggLookback: aggLookback,
|
||||
aggUsageDays: aggUsageDays,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
stats, err := s.usageRepo.GetDashboardStats(ctx)
|
||||
if s.cache != nil {
|
||||
cached, fresh, err := s.getCachedDashboardStats(ctx)
|
||||
if err == nil && cached != nil {
|
||||
s.refreshAggregationStaleness(cached)
|
||||
if !fresh {
|
||||
s.refreshDashboardStatsAsync()
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.refreshDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||
data, err := s.cache.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||
s.evictDashboardStatsCache(err)
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
if entry.Stats == nil {
|
||||
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
age := time.Since(time.Unix(entry.UpdatedAt, 0))
|
||||
return entry.Stats, age <= s.cacheFreshTTL, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreInt32(&s.refreshing, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
if !s.aggEnabled {
|
||||
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
|
||||
now := time.Now().UTC()
|
||||
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
|
||||
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
|
||||
}
|
||||
}
|
||||
return s.usageRepo.GetDashboardStats(ctx)
|
||||
}
|
||||
|
||||
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if s.cache == nil || stats == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||
}
|
||||
if reason != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
}
|
||||
|
||||
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := s.fetchAggregationUpdatedAt(ctx)
|
||||
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
|
||||
if s.aggRepo == nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
if updatedAt.IsZero() {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return updatedAt.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
|
||||
if !s.aggEnabled {
|
||||
return true
|
||||
}
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
if !updatedAt.After(epoch) {
|
||||
return true
|
||||
}
|
||||
threshold := s.aggInterval + s.aggLookback
|
||||
return now.Sub(updatedAt) > threshold
|
||||
}
|
||||
|
||||
func parseStatsUpdatedAt(raw string) time.Time {
|
||||
if raw == "" {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
parsed, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return parsed.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
|
||||
387
backend/internal/service/dashboard_service_test.go
Normal file
387
backend/internal/service/dashboard_service_test.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type usageRepoStub struct {
|
||||
UsageLogRepository
|
||||
stats *usagestats.DashboardStats
|
||||
rangeStats *usagestats.DashboardStats
|
||||
err error
|
||||
rangeErr error
|
||||
calls int32
|
||||
rangeCalls int32
|
||||
rangeStart time.Time
|
||||
rangeEnd time.Time
|
||||
onCall chan struct{}
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.onCall != nil {
|
||||
select {
|
||||
case s.onCall <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.rangeCalls, 1)
|
||||
s.rangeStart = start
|
||||
s.rangeEnd = end
|
||||
if s.rangeErr != nil {
|
||||
return nil, s.rangeErr
|
||||
}
|
||||
if s.rangeStats != nil {
|
||||
return s.rangeStats, nil
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
type dashboardCacheStub struct {
|
||||
get func(ctx context.Context) (string, error)
|
||||
set func(ctx context.Context, data string, ttl time.Duration) error
|
||||
del func(ctx context.Context) error
|
||||
getCalls int32
|
||||
setCalls int32
|
||||
delCalls int32
|
||||
lastSetMu sync.Mutex
|
||||
lastSet string
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
atomic.AddInt32(&c.getCalls, 1)
|
||||
if c.get != nil {
|
||||
return c.get(ctx)
|
||||
}
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&c.setCalls, 1)
|
||||
c.lastSetMu.Lock()
|
||||
c.lastSet = data
|
||||
c.lastSetMu.Unlock()
|
||||
if c.set != nil {
|
||||
return c.set(ctx, data, ttl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
|
||||
atomic.AddInt32(&c.delCalls, 1)
|
||||
if c.del != nil {
|
||||
return c.del(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dashboardAggregationRepoStub struct {
|
||||
watermark time.Time
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
if s.err != nil {
|
||||
return time.Time{}, s.err
|
||||
}
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
|
||||
t.Helper()
|
||||
c.lastSetMu.Lock()
|
||||
data := c.lastSet
|
||||
c.lastSetMu.Unlock()
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
err := json.Unmarshal([]byte(data), &entry)
|
||||
require.NoError(t, err)
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitFresh(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 10,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 99},
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 7,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
|
||||
entry := cache.readLastEntry(t)
|
||||
require.Equal(t, stats, entry.Stats)
|
||||
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 3,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
|
||||
staleStats := &usagestats.DashboardStats{
|
||||
TotalUsers: 11,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: staleStats,
|
||||
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
refreshCh := make(chan struct{}, 1)
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 22},
|
||||
onCall: refreshCh,
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, staleStats, got)
|
||||
|
||||
select {
|
||||
case <-refreshCh:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("等待异步刷新超时")
|
||||
}
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt32(&cache.setCalls) >= 1
|
||||
}, 1*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
stats := &usagestats.DashboardStats{TotalUsers: 9}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{err: errors.New("db down")}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
_, err := svc.GetDashboardStats(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
|
||||
require.True(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
|
||||
aggNow := time.Now().UTC().Truncate(time.Second)
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
|
||||
require.False(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
|
||||
expected := &usagestats.DashboardStats{TotalUsers: 42}
|
||||
repo := &usageRepoStub{
|
||||
rangeStats: expected,
|
||||
err: errors.New("should not call aggregated stats"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: false,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 7,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, nil, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), got.TotalUsers)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
|
||||
require.False(t, repo.rangeEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
|
||||
}
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
@@ -57,6 +63,9 @@ const (
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
@@ -77,6 +86,12 @@ const (
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// LinuxDo Connect OAuth 登录设置
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
@@ -84,6 +99,7 @@ const (
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
@@ -106,16 +122,38 @@ const (
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
)
|
||||
// =========================
|
||||
// Ops Monitoring (vNext)
|
||||
// =========================
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// SettingKeyOpsMonitoringEnabled is a DB-backed soft switch to enable/disable ops module at runtime.
|
||||
SettingKeyOpsMonitoringEnabled = "ops_monitoring_enabled"
|
||||
|
||||
// SettingKeyOpsRealtimeMonitoringEnabled controls realtime features (e.g. WS/QPS push).
|
||||
SettingKeyOpsRealtimeMonitoringEnabled = "ops_realtime_monitoring_enabled"
|
||||
|
||||
// SettingKeyOpsQueryModeDefault controls the default query mode for ops dashboard (auto/raw/preagg).
|
||||
SettingKeyOpsQueryModeDefault = "ops_query_mode_default"
|
||||
|
||||
// SettingKeyOpsEmailNotificationConfig stores JSON config for ops email notifications.
|
||||
SettingKeyOpsEmailNotificationConfig = "ops_email_notification_config"
|
||||
|
||||
// SettingKeyOpsAlertRuntimeSettings stores JSON config for ops alert evaluator runtime settings.
|
||||
SettingKeyOpsAlertRuntimeSettings = "ops_alert_runtime_settings"
|
||||
|
||||
// SettingKeyOpsMetricsIntervalSeconds controls the ops metrics collector interval (>=60).
|
||||
SettingKeyOpsMetricsIntervalSeconds = "ops_metrics_interval_seconds"
|
||||
|
||||
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
||||
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
||||
|
||||
// =========================
|
||||
// Stream Timeout Handling
|
||||
// =========================
|
||||
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
m.getByIDCalls++
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
@@ -142,6 +145,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -157,6 +163,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -194,6 +203,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -900,6 +959,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
|
||||
return m.accountWaitCounts[accountID], nil
|
||||
}
|
||||
|
||||
type mockConcurrencyCache struct {
|
||||
acquireAccountCalls int
|
||||
loadBatchCalls int
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
m.acquireAccountCalls++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
m.loadBatchCalls++
|
||||
result := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
result[acc.ID] = &AccountLoadInfo{
|
||||
AccountID: acc.ID,
|
||||
CurrentConcurrency: 0,
|
||||
WaitingCount: 0,
|
||||
LoadRate: 0,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -928,13 +1055,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
sessionHash := "sticky"
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{sessionHash: 1},
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-a": {1},
|
||||
"claude-b": {2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil, // legacy path
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
|
||||
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
|
||||
})
|
||||
|
||||
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
@@ -959,7 +1140,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -991,13 +1172,85 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||
})
|
||||
|
||||
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID)
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
|
||||
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
|
||||
})
|
||||
|
||||
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
|
||||
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
|
||||
})
|
||||
|
||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
@@ -1016,9 +1269,264 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
|
||||
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
|
||||
})
|
||||
|
||||
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
overloadUntil := now.Add(10 * time.Minute)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
ctxGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
|
||||
groupID := int64(42)
|
||||
invalidGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
hydratedGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
|
||||
svc := &GatewayService{}
|
||||
ctx = svc.withGroupContext(ctx, hydratedGroup)
|
||||
|
||||
got, ok := ctx.Value(ctxkey.Group).(*Group)
|
||||
require.True(t, ok)
|
||||
require.Same(t, hydratedGroup, got)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{fallbackID: fallbackGroup},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &groupID,
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: group,
|
||||
fallbackID: fallbackGroup,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
|
||||
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gotGroup)
|
||||
require.Nil(t, gotID)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,7 @@ type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
cache GatewayCache
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
@@ -51,6 +52,7 @@ func NewGeminiMessagesCompatService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cache GatewayCache,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
tokenProvider *GeminiTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
@@ -61,6 +63,7 @@ func NewGeminiMessagesCompatService(
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
@@ -86,9 +89,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
@@ -99,12 +108,6 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
var queryPlatforms []string
|
||||
if useMixedScheduling {
|
||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
||||
} else {
|
||||
queryPlatforms = []string{platform}
|
||||
}
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
@@ -112,7 +115,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
@@ -143,22 +146,16 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
|
||||
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||
if len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
@@ -239,6 +236,31 @@ func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *Antigravit
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
return accounts, err
|
||||
}
|
||||
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
queryPlatforms := []string{platform}
|
||||
if useMixedScheduling {
|
||||
queryPlatforms = []string{platform, PlatformAntigravity}
|
||||
}
|
||||
|
||||
if groupID != nil {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
@@ -260,13 +282,7 @@ func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (strin
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformAntigravity, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -282,13 +298,7 @@ func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context
|
||||
// 3) OAuth accounts explicitly marked as ai_studio
|
||||
// 4) Any remaining Gemini accounts (fallback)
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, PlatformGemini, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
@@ -535,14 +545,30 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
// In this code path `body` is already the JSON sent to upstream.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
|
||||
}
|
||||
|
||||
// Special-case: signature/thought_signature validation errors are not transient, but may be fixed by
|
||||
@@ -552,6 +578,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if isGeminiSignatureRelatedError(respBody) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "signature_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
var strippedClaudeBody []byte
|
||||
stageName := ""
|
||||
switch signatureRetryStage {
|
||||
@@ -602,6 +653,31 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if attempt < geminiMaxRetries {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
@@ -627,12 +703,64 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody)
|
||||
}
|
||||
|
||||
requestID := resp.Header.Get(requestIDHeader)
|
||||
@@ -855,8 +983,23 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
// In this code path `body` is already the JSON sent to upstream.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
sleepGeminiBackoff(attempt)
|
||||
@@ -874,7 +1017,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+sanitizeUpstreamErrorMessage(err.Error()))
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
|
||||
@@ -893,6 +1037,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if attempt < geminiMaxRetries {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
upstreamReqID = resp.Header.Get("x-goog-request-id")
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "retry",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
|
||||
sleepGeminiBackoff(attempt)
|
||||
continue
|
||||
@@ -956,19 +1125,87 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(evBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, respBody)
|
||||
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("gemini upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
@@ -1070,7 +1307,33 @@ func sanitizeUpstreamErrorMessage(msg string) string {
|
||||
return sensitiveQueryParamRegex.ReplaceAllString(msg, `$1***`)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||||
func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, upstreamStatus, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: upstreamStatus,
|
||||
UpstreamRequestID: upstreamRequestID,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
|
||||
@@ -1178,7 +1441,10 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, ups
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": errMsg},
|
||||
})
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
type claudeErrorMapping struct {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -127,6 +128,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -140,6 +144,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
|
||||
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -155,10 +162,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
@@ -251,6 +269,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformGemini,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {ID: groupID, Platform: PlatformGemini},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
|
||||
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
||||
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
||||
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
||||
DeleteAccessToken(ctx context.Context, cacheKey string) error
|
||||
|
||||
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
||||
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
}
|
||||
|
||||
cacheKey := geminiTokenCacheKey(account)
|
||||
cacheKey := GeminiTokenCacheKey(account)
|
||||
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func geminiTokenCacheKey(account *Account) string {
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return projectID
|
||||
return "gemini:" + projectID
|
||||
}
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64
|
||||
@@ -10,6 +13,7 @@ type Group struct {
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
Hydrated bool // indicates the group was loaded from a trusted repository source
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
@@ -26,6 +30,12 @@ type Group struct {
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
// value: 优先账号 ID 列表
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -72,3 +82,58 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if group.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if !group.Hydrated {
|
||||
return false
|
||||
}
|
||||
if group.Platform == "" || group.Status == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
|
||||
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
|
||||
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
|
||||
if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 精确匹配优先
|
||||
if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
|
||||
// 2. 通配符匹配(前缀匹配)
|
||||
for pattern, accountIDs := range g.ModelRouting {
|
||||
if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchModelPattern 检查模型是否匹配模式
|
||||
// 支持 * 通配符,如 "claude-opus-*" 匹配 "claude-opus-4-20250514"
|
||||
func matchModelPattern(pattern, model string) bool {
|
||||
if pattern == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// 处理 * 通配符(仅支持末尾通配符)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
return strings.HasPrefix(model, prefix)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
GetByIDLite(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
@@ -49,13 +50,15 @@ type UpdateGroupRequest struct {
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo GroupRepository
|
||||
groupRepo GroupRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo GroupRepository) *GroupService {
|
||||
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
groupRepo: groupRepo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("update group: %w", err)
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
@@ -166,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
|
||||
56
backend/internal/service/model_rate_limit.go
Normal file
56
backend/internal/service/model_rate_limit.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const modelRateLimitsKey = "model_rate_limits"
|
||||
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
|
||||
|
||||
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
|
||||
model := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
if strings.Contains(model, "sonnet") {
|
||||
return modelRateLimitScopeClaudeSonnet, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (a *Account) isModelRateLimited(requestedModel string) bool {
|
||||
scope, ok := resolveModelRateLimitScope(requestedModel)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
resetAt := a.modelRateLimitResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*resetAt)
|
||||
}
|
||||
|
||||
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawLimit, ok := rawLimits[scope].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
528
backend/internal/service/openai_codex_transform.go
Normal file
528
backend/internal/service/openai_codex_transform.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||
codexCacheTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
//go:embed prompts/codex_cli_instructions.md
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-high": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
"gpt-5.2-none": "gpt-5.2",
|
||||
"gpt-5.2-low": "gpt-5.2",
|
||||
"gpt-5.2-medium": "gpt-5.2",
|
||||
"gpt-5.2-high": "gpt-5.2",
|
||||
"gpt-5.2-xhigh": "gpt-5.2",
|
||||
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-low": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-medium": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-high": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1": "gpt-5.1",
|
||||
"gpt-5.1-none": "gpt-5.1",
|
||||
"gpt-5.1-low": "gpt-5.1",
|
||||
"gpt-5.1-medium": "gpt-5.1",
|
||||
"gpt-5.1-high": "gpt-5.1",
|
||||
"gpt-5.1-chat-latest": "gpt-5.1",
|
||||
"gpt-5-codex": "gpt-5.1-codex",
|
||||
"codex-mini-latest": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5": "gpt-5.1",
|
||||
"gpt-5-mini": "gpt-5.1",
|
||||
"gpt-5-nano": "gpt-5.1",
|
||||
}
|
||||
|
||||
type codexTransformResult struct {
|
||||
Modified bool
|
||||
NormalizedModel string
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
type opencodeCacheMetadata struct {
|
||||
ETag string `json:"etag"`
|
||||
LastFetch string `json:"lastFetch,omitempty"`
|
||||
LastChecked int64 `json:"lastChecked"`
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||
|
||||
model := ""
|
||||
if v, ok := reqBody["model"].(string); ok {
|
||||
model = v
|
||||
}
|
||||
normalizedModel := normalizeCodexModel(model)
|
||||
if normalizedModel != "" {
|
||||
if model != normalizedModel {
|
||||
reqBody["model"] = normalizedModel
|
||||
result.Modified = true
|
||||
}
|
||||
result.NormalizedModel = normalizedModel
|
||||
}
|
||||
|
||||
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||
// 避免上游返回 "Store must be set to false"。
|
||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||
reqBody["store"] = false
|
||||
result.Modified = true
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||
reqBody["stream"] = true
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
delete(reqBody, "max_output_tokens")
|
||||
result.Modified = true
|
||||
}
|
||||
if _, ok := reqBody["max_completion_tokens"]; ok {
|
||||
delete(reqBody, "max_completion_tokens")
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if normalizeCodexTools(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
result.PromptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||
|
||||
if instructions != "" {
|
||||
if existingInstructions != instructions {
|
||||
reqBody["instructions"] = instructions
|
||||
result.Modified = true
|
||||
}
|
||||
} else if existingInstructions == "" {
|
||||
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions != "" {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
result.Modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||
if input, ok := reqBody["input"].([]any); ok {
|
||||
input = filterCodexInput(input, needsToolContinuation)
|
||||
reqBody["input"] = input
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
if model == "" {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
modelID := model
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(modelID)
|
||||
|
||||
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||
return "gpt-5.2-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
|
||||
return "gpt-5.1-codex-mini"
|
||||
}
|
||||
if strings.Contains(normalized, "codex-mini-latest") ||
|
||||
strings.Contains(normalized, "gpt-5-codex-mini") ||
|
||||
strings.Contains(normalized, "gpt 5 codex mini") {
|
||||
return "codex-mini-latest"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
|
||||
return "gpt-5.1-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
if strings.Contains(normalized, "codex") {
|
||||
return "gpt-5.1-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
func getNormalizedCodexModel(modelID string) string {
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
if mapped, ok := codexModelMap[modelID]; ok {
|
||||
return mapped
|
||||
}
|
||||
lower := strings.ToLower(modelID)
|
||||
for key, value := range codexModelMap {
|
||||
if strings.ToLower(key) == lower {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||
cacheDir := codexCachePath("")
|
||||
if cacheDir == "" {
|
||||
return ""
|
||||
}
|
||||
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||
|
||||
var cachedContent string
|
||||
if content, ok := readFile(cacheFile); ok {
|
||||
cachedContent = content
|
||||
}
|
||||
|
||||
var meta opencodeCacheMetadata
|
||||
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||
return cachedContent
|
||||
}
|
||||
}
|
||||
|
||||
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||
return cachedContent
|
||||
}
|
||||
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||
_ = writeFile(cacheFile, content)
|
||||
meta = opencodeCacheMetadata{
|
||||
ETag: etag,
|
||||
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||
LastChecked: time.Now().UnixMilli(),
|
||||
}
|
||||
_ = writeJSON(metaFile, meta)
|
||||
return content
|
||||
}
|
||||
|
||||
return cachedContent
|
||||
}
|
||||
|
||||
func getOpenCodeCodexHeader() string {
|
||||
// 优先从 opencode 仓库缓存获取指令。
|
||||
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||
|
||||
// 若 opencode 指令可用,直接返回。
|
||||
if opencodeInstructions != "" {
|
||||
return opencodeInstructions
|
||||
}
|
||||
|
||||
// 否则回退使用本地 Codex CLI 指令。
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
func getCodexCLIInstructions() string {
|
||||
return codexCLIInstructions
|
||||
}
|
||||
|
||||
func GetOpenCodeInstructions() string {
|
||||
return getOpenCodeCodexHeader()
|
||||
}
|
||||
|
||||
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||
func GetCodexCLIInstructions() string {
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) != codexInstructions {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
|
||||
func IsInstructionError(errorMessage string) bool {
|
||||
if errorMessage == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMessage)
|
||||
instructionKeywords := []string{
|
||||
"instruction",
|
||||
"instructions",
|
||||
"system prompt",
|
||||
"system message",
|
||||
"invalid prompt",
|
||||
"prompt format",
|
||||
}
|
||||
|
||||
for _, keyword := range instructionKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
filtered := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
if typ == "item_reference" {
|
||||
if !preserveReferences {
|
||||
continue
|
||||
}
|
||||
newItem := make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
}
|
||||
|
||||
newItem := m
|
||||
copied := false
|
||||
// 仅在需要修改字段时创建副本,避免直接改写原始输入。
|
||||
ensureCopy := func() {
|
||||
if copied {
|
||||
return
|
||||
}
|
||||
newItem = make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
copied = true
|
||||
}
|
||||
|
||||
if isCodexToolCallItemType(typ) {
|
||||
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
||||
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !preserveReferences {
|
||||
ensureCopy()
|
||||
delete(newItem, "id")
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCodexToolCallItemType(typ string) bool {
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
|
||||
}
|
||||
|
||||
func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for idx, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
toolType, _ := toolMap["type"].(string)
|
||||
if strings.TrimSpace(toolType) != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
function, ok := toolMap["function"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := toolMap["name"]; !ok {
|
||||
if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
toolMap["name"] = name
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["description"]; !ok {
|
||||
if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
|
||||
toolMap["description"] = desc
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["parameters"]; !ok {
|
||||
if params, ok := function["parameters"]; ok {
|
||||
toolMap["parameters"] = params
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["strict"]; !ok {
|
||||
if strict, ok := function["strict"]; ok {
|
||||
toolMap["strict"] = strict
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
|
||||
if modified {
|
||||
reqBody["tools"] = tools
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func codexCachePath(filename string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||
if filename == "" {
|
||||
return cacheDir
|
||||
}
|
||||
return filepath.Join(cacheDir, filename)
|
||||
}
|
||||
|
||||
func readFile(path string) (string, bool) {
|
||||
if path == "" {
|
||||
return "", false
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(data), true
|
||||
}
|
||||
|
||||
func writeFile(path, content string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty cache path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
func loadJSON(path string, target any) bool {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(data, target); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func writeJSON(path string, value any) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty json path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "sub2api-codex")
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", "", resp.StatusCode, err
|
||||
}
|
||||
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||
}
|
||||
167
backend/internal/service/openai_codex_transform_test.go
Normal file
167
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
// 未显式设置 store=true,默认为 false。
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
// 校验 input[0] 为 map,避免断言失败导致测试中断。
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "o1", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"store": false,
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
// 显式 store=true 也会强制为 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"store": true,
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{
|
||||
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
// 校验 input[0] 为 map,避免类型不匹配触发 errcheck。
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasID := item["id"]
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
input := []any{
|
||||
map[string]any{"type": "item_reference", "id": "ref1"},
|
||||
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||
}
|
||||
|
||||
filtered := filterCodexInput(input, false)
|
||||
require.Len(t, filtered, 1)
|
||||
// 校验 filtered[0] 为 map,确保字段检查可靠。
|
||||
item, ok := filtered[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", item["type"])
|
||||
_, hasID := item["id"]
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 0)
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// 使用临时 HOME 避免触发网络拉取 header。
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("HOME", tempDir)
|
||||
|
||||
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||
|
||||
meta := map[string]any{
|
||||
"etag": "",
|
||||
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||
"lastChecked": time.Now().UnixMilli(),
|
||||
}
|
||||
data, err := json.Marshal(meta)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -41,6 +42,7 @@ var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"content-type": true,
|
||||
"conversation_id": true,
|
||||
"user-agent": true,
|
||||
"originator": true,
|
||||
"session_id": true,
|
||||
@@ -84,12 +86,15 @@ type OpenAIGatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@@ -100,12 +105,14 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -114,12 +121,15 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +168,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
@@ -169,16 +179,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []Account
|
||||
var err error
|
||||
// 简易模式:忽略分组限制,查询所有可用账号
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
}
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
@@ -190,6 +191,11 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
@@ -300,7 +306,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
@@ -336,6 +342,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if isExcluded(acc.ID) {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
|
||||
// re-check schedulability here so recently rate-limited/overloaded accounts
|
||||
// are not selected again before the bucket is rebuilt.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -445,6 +457,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false)
|
||||
return accounts, err
|
||||
}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
@@ -467,6 +483,13 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
@@ -485,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case AccountTypeOAuth:
|
||||
// 使用 TokenProvider 获取缓存的 token
|
||||
if s.openAITokenProvider != nil {
|
||||
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
// 降级:TokenProvider 未配置时直接从账号读取
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
@@ -511,7 +543,7 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
@@ -528,33 +560,97 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
promptCacheKey := ""
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
// Apply model mapping
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API:
|
||||
// 1. Add store: false
|
||||
// 2. Normalize input format for Codex API compatibility
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
|
||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
||||
// Codex API expects: {"content": "..."}
|
||||
if normalizeInputForCodexAPI(reqBody) {
|
||||
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
|
||||
if model, ok := reqBody["model"].(string); ok {
|
||||
normalizedModel := normalizeCodexModel(model)
|
||||
if normalizedModel != "" && normalizedModel != model {
|
||||
log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
|
||||
model, normalizedModel, account.Name, account.Type, isCodexCLI)
|
||||
reqBody["model"] = normalizedModel
|
||||
mappedModel = normalizedModel
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。
|
||||
if reasoning, ok := reqBody["reasoning"].(map[string]any); ok {
|
||||
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
|
||||
reasoning["effort"] = "none"
|
||||
bodyModified = true
|
||||
log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||
codexResult := applyCodexOAuthTransform(reqBody)
|
||||
if codexResult.Modified {
|
||||
bodyModified = true
|
||||
}
|
||||
if codexResult.NormalizedModel != "" {
|
||||
mappedModel = codexResult.NormalizedModel
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
}
|
||||
}
|
||||
|
||||
// Handle max_output_tokens based on platform and account type
|
||||
if !isCodexCLI {
|
||||
if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
// For OpenAI API Key, remove max_output_tokens (not supported)
|
||||
// For OpenAI OAuth (Responses API), keep it (supported)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
delete(reqBody, "max_output_tokens")
|
||||
bodyModified = true
|
||||
}
|
||||
case PlatformAnthropic:
|
||||
// For Anthropic (Claude), convert to max_tokens
|
||||
delete(reqBody, "max_output_tokens")
|
||||
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
|
||||
reqBody["max_tokens"] = maxOutputTokens
|
||||
}
|
||||
bodyModified = true
|
||||
case PlatformGemini:
|
||||
// For Gemini, remove (will be handled by Gemini-specific transform)
|
||||
delete(reqBody, "max_output_tokens")
|
||||
bodyModified = true
|
||||
default:
|
||||
// For unknown platforms, remove to be safe
|
||||
delete(reqBody, "max_output_tokens")
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// Also handle max_completion_tokens (similar logic)
|
||||
if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens {
|
||||
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
|
||||
delete(reqBody, "max_completion_tokens")
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
var err error
|
||||
@@ -571,7 +667,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -582,16 +678,63 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -632,7 +775,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
switch account.Type {
|
||||
@@ -672,12 +815,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
// Set accept header based on stream mode
|
||||
if isStream {
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist passthrough headers
|
||||
@@ -689,6 +826,19 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
if isCodexCLI {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
} else {
|
||||
req.Header.Set("originator", "opencode")
|
||||
}
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if promptCacheKey != "" {
|
||||
req.Header.Set("conversation_id", promptCacheKey)
|
||||
req.Header.Set("session_id", promptCacheKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom User-Agent if configured
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
@@ -705,17 +855,53 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(body), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"OpenAI upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
}
|
||||
|
||||
// Check custom error codes
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream gateway error",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
@@ -723,6 +909,20 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: kind,
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -761,7 +961,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
if upstreamMsg == "" {
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
// openaiStreamingResult streaming response result
|
||||
@@ -905,6 +1108,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
@@ -933,6 +1141,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
@@ -988,6 +1200,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
||||
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
|
||||
if changed {
|
||||
return []byte(corrected)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
@@ -1016,6 +1242,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
|
||||
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
|
||||
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse usage
|
||||
var response struct {
|
||||
Usage struct {
|
||||
@@ -1055,6 +1288,112 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func isEventStreamResponse(header http.Header) bool {
|
||||
contentType := strings.ToLower(header.Get("Content-Type"))
|
||||
return strings.Contains(contentType, "text/event-stream")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
bodyText := string(body)
|
||||
finalResponse, ok := extractCodexFinalResponse(bodyText)
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
if ok {
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(finalResponse, &response); err == nil {
|
||||
usage.InputTokens = response.Usage.InputTokens
|
||||
usage.OutputTokens = response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
body = finalResponse
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
// Correct tool calls in final response
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
} else {
|
||||
usage = s.parseSSEUsageFromBody(bodyText)
|
||||
if originalModel != mappedModel {
|
||||
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
|
||||
}
|
||||
body = []byte(bodyText)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := "application/json; charset=utf-8"
|
||||
if !ok {
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "text/event-stream"
|
||||
}
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response json.RawMessage `json:"response"`
|
||||
}
|
||||
if json.Unmarshal([]byte(data), &event) != nil {
|
||||
continue
|
||||
}
|
||||
if event.Type == "response.done" || event.Type == "response.completed" {
|
||||
if len(event.Response) > 0 {
|
||||
return event.Response, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
usage := &OpenAIUsage{}
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||
lines := strings.Split(body, "\n")
|
||||
for i, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
|
||||
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
@@ -1094,101 +1433,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
return newBody
|
||||
}
|
||||
|
||||
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
|
||||
// that the ChatGPT internal Codex API expects.
|
||||
//
|
||||
// AI SDK sends content as an array of typed objects:
|
||||
//
|
||||
// {"content": [{"type": "input_text", "text": "hello"}]}
|
||||
//
|
||||
// ChatGPT Codex API expects content as a simple string:
|
||||
//
|
||||
// {"content": "hello"}
|
||||
//
|
||||
// This function modifies reqBody in-place and returns true if any modification was made.
|
||||
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
|
||||
input, ok := reqBody["input"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is a simple string (already compatible)
|
||||
if _, isString := input.(string); isString {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle case where input is an array of messages
|
||||
inputArray, ok := input.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, item := range inputArray {
|
||||
message, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
content, ok := message["content"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is already a string, no conversion needed
|
||||
if _, isString := content.(string); isString {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is an array (AI SDK format), convert to string
|
||||
contentArray, ok := content.([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract text from content array
|
||||
var textParts []string
|
||||
for _, part := range contentArray {
|
||||
partMap, ok := part.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle different content types
|
||||
partType, _ := partMap["type"].(string)
|
||||
switch partType {
|
||||
case "input_text", "text":
|
||||
// Extract text from input_text or text type
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
case "input_image", "image":
|
||||
// For images, we need to preserve the original format
|
||||
// as ChatGPT Codex API may support images in a different way
|
||||
// For now, skip image parts (they will be lost in conversion)
|
||||
// TODO: Consider preserving image data or handling it separately
|
||||
continue
|
||||
case "input_file", "file":
|
||||
// Similar to images, file inputs may need special handling
|
||||
continue
|
||||
default:
|
||||
// For unknown types, try to extract text if available
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
textParts = append(textParts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert content array to string
|
||||
if len(textParts) > 0 {
|
||||
message["content"] = strings.Join(textParts, "\n")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
@@ -1197,6 +1441,7 @@ type OpenAIRecordUsageInput struct {
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -1242,28 +1487,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
@@ -1271,6 +1518,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.UserAgent = &input.UserAgent
|
||||
}
|
||||
|
||||
// 添加 IPAddress
|
||||
if input.IPAddress != "" {
|
||||
usageLog.IPAddress = &input.IPAddress
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -15,6 +16,129 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type stubOpenAIAccountRepo struct {
|
||||
AccountRepository
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
groupID := int64(1)
|
||||
|
||||
rateLimited := Account{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
RateLimitResetAt: &resetAt,
|
||||
}
|
||||
available := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection with account")
|
||||
}
|
||||
if selection.Account.ID != available.ID {
|
||||
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
groupID := int64(1)
|
||||
|
||||
rateLimited := Account{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
RateLimitResetAt: &resetAt,
|
||||
}
|
||||
available := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||
// concurrencyService is nil, forcing the non-load-batch selection path.
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection with account")
|
||||
}
|
||||
if selection.Account.ID != available.ID {
|
||||
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -220,7 +344,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||
}
|
||||
|
||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
|
||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
|
||||
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
|
||||
// 创建一个简单的 service 实例来测试工具修正
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
changed bool
|
||||
}{
|
||||
{
|
||||
name: "correct apply_patch in response body",
|
||||
input: []byte(`{
|
||||
"choices": [{
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"function": {"name": "apply_patch"}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in response body",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "update_plan"}
|
||||
}]
|
||||
}`),
|
||||
expected: "todowrite",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "no change for correct tool name",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "edit"}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := service.correctToolCallsInResponseBody(tt.input)
|
||||
resultStr := string(result)
|
||||
|
||||
// 检查是否包含期望的工具名称
|
||||
if !strings.Contains(resultStr, tt.expected) {
|
||||
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
|
||||
}
|
||||
|
||||
// 对于预期有变化的情况,验证结果与输入不同
|
||||
if tt.changed && string(result) == string(tt.input) {
|
||||
t.Error("expected result to be different from input, but they are the same")
|
||||
}
|
||||
|
||||
// 对于预期无变化的情况,验证结果与输入相同
|
||||
if !tt.changed && string(result) != string(tt.input) {
|
||||
t.Error("expected result to be same as input, but they are different")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
|
||||
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
if service.toolCorrector == nil {
|
||||
t.Fatal("toolCorrector should not be nil")
|
||||
}
|
||||
|
||||
// 测试修正器可以正常工作
|
||||
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
|
||||
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
|
||||
if !changed {
|
||||
t.Error("expected tool call to be corrected")
|
||||
}
|
||||
|
||||
if !strings.Contains(corrected, "edit") {
|
||||
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCorrectionStats 测试工具修正统计功能
|
||||
func TestToolCorrectionStats(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
// 执行几次修正
|
||||
testData := []string{
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
}
|
||||
|
||||
stats := service.toolCorrector.GetStats()
|
||||
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
}
|
||||
189
backend/internal/service/openai_token_provider.go
Normal file
189
backend/internal/service/openai_token_provider.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAITokenRefreshSkew = 3 * time.Minute
|
||||
openAITokenCacheSkew = 5 * time.Minute
|
||||
openAILockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
|
||||
type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewOpenAITokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache OpenAITokenCache,
|
||||
openAIOAuthService *OpenAIOAuthService,
|
||||
) *OpenAITokenProvider {
|
||||
return &OpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
openAIOAuthService: openAIOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(openAILockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
810
backend/internal/service/openai_token_provider_test.go
Normal file
810
backend/internal/service/openai_token_provider_test.go
Normal file
@@ -0,0 +1,810 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// openAITokenCacheStub implements OpenAITokenCache for testing
|
||||
type openAITokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newOpenAITokenCacheStub() *openAITokenCacheStub {
|
||||
return &openAITokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
|
||||
type openAIAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
|
||||
type openAIOAuthServiceStub struct {
|
||||
tokenInfo *OpenAITokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
|
||||
now := time.Now()
|
||||
return map[string]any{
|
||||
"access_token": info.AccessToken,
|
||||
"refresh_token": info.RefreshToken,
|
||||
"expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// We need to directly test with the stub - create a custom provider
|
||||
customProvider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := customProvider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
// testOpenAITokenProvider is a test version that uses the stub OAuth service
|
||||
type testOpenAITokenProvider struct {
|
||||
accountRepo *openAIAccountRepoStub
|
||||
tokenCache *openAITokenCacheStub
|
||||
oauthService *openAIOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond) // Short wait for test
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > openAITokenCacheSkew {
|
||||
ttl = until - openAITokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get the token set by the "winner" or the original
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
|
||||
originalGet := int32(0)
|
||||
cache.tokens[cacheKey] = "" // Empty initially
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// In a goroutine, set the cached token after a small delay (simulating race)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the refreshed token or the cached one
|
||||
require.NotEmpty(t, token)
|
||||
_ = originalGet // Suppress unused warning
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the fallback token or the refreshed one
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 201,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set) - should use credentials
|
||||
account := &Account{
|
||||
ID: 202,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
// Without OAuth service, refresh will fail but token should be returned from credentials
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cacheKey := "openai:account:203"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 203,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token) // Should fall back to credentials
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 204,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 205,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 206,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||
return true
|
||||
}
|
||||
if hasToolsSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if hasToolChoiceSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "function_call_output") {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "item_reference") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
return inputHasType(reqBody, "function_call_output")
|
||||
}
|
||||
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||
func HasToolCallContext(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ids := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
ids[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||
// 用于仅依赖引用项完成续链场景的校验。
|
||||
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||
if reqBody == nil || len(callIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
referenceIDs := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "item_reference" {
|
||||
continue
|
||||
}
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
if len(referenceIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||
func inputHasType(reqBody map[string]any, want string) bool {
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||
func hasNonEmptyString(value any) bool {
|
||||
stringValue, ok := value.(string)
|
||||
return ok && strings.TrimSpace(stringValue) != ""
|
||||
}
|
||||
|
||||
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||
func hasToolsSignal(reqBody map[string]any) bool {
|
||||
raw, exists := reqBody["tools"]
|
||||
if !exists || raw == nil {
|
||||
return false
|
||||
}
|
||||
if tools, ok := raw.([]any); ok {
|
||||
return len(tools) > 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||
raw, exists := reqBody["tool_choice"]
|
||||
if !exists || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value) != ""
|
||||
case map[string]any:
|
||||
return len(value) > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||
cases := []struct {
|
||||
name string
|
||||
body map[string]any
|
||||
want bool
|
||||
}{
|
||||
{name: "nil", body: nil, want: false},
|
||||
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutput(t *testing.T) {
|
||||
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||
require.False(t, HasFunctionCallOutput(nil))
|
||||
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": "text",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHasToolCallContext(t *testing.T) {
|
||||
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||
require.False(t, HasToolCallContext(nil))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||
}))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||
}))
|
||||
require.False(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call"}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||
// 仅提取非空 call_id,去重后返回。
|
||||
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
})
|
||||
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||
req := map[string]any{
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||
},
|
||||
}
|
||||
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||
}
|
||||
307
backend/internal/service/openai_tool_corrector.go
Normal file
307
backend/internal/service/openai_tool_corrector.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
var codexToolNameMapping = map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"applyPatch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"updatePlan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"readPlan": "todoread",
|
||||
"search_files": "grep",
|
||||
"searchFiles": "grep",
|
||||
"list_files": "glob",
|
||||
"listFiles": "glob",
|
||||
"read_file": "read",
|
||||
"readFile": "read",
|
||||
"write_file": "write",
|
||||
"writeFile": "write",
|
||||
"execute_bash": "bash",
|
||||
"executeBash": "bash",
|
||||
"exec_bash": "bash",
|
||||
"execBash": "bash",
|
||||
}
|
||||
|
||||
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
|
||||
type ToolCorrectionStats struct {
|
||||
TotalCorrected int `json:"total_corrected"`
|
||||
CorrectionsByTool map[string]int `json:"corrections_by_tool"`
|
||||
}
|
||||
|
||||
// CodexToolCorrector 处理 Codex 工具调用的自动修正
|
||||
type CodexToolCorrector struct {
|
||||
stats ToolCorrectionStats
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCodexToolCorrector 创建新的工具修正器
|
||||
func NewCodexToolCorrector() *CodexToolCorrector {
|
||||
return &CodexToolCorrector{
|
||||
stats: ToolCorrectionStats{
|
||||
CorrectionsByTool: make(map[string]int),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
|
||||
// 返回修正后的数据和是否进行了修正
|
||||
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
|
||||
if data == "" || data == "\n" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 尝试解析 JSON
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
// 不是有效的 JSON,直接返回原数据
|
||||
return data, false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 处理 tool_calls 数组
|
||||
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 function_call 对象
|
||||
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 delta.tool_calls
|
||||
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||
if choices, ok := payload["choices"].([]any); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]any); ok {
|
||||
// 处理 message 中的工具调用
|
||||
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理 delta 中的工具调用
|
||||
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
return data, false
|
||||
}
|
||||
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||
corrected := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(function) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||
name, ok := functionCall["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 查找并修正工具名称
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
functionCall["name"] = correctName
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
}
|
||||
|
||||
// 修正工具参数(基于工具名称)
|
||||
if c.correctToolParameters(name, functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||
arguments, ok := functionCall["arguments"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// arguments 可能是字符串(JSON)或已解析的 map
|
||||
var argsMap map[string]any
|
||||
switch v := arguments.(type) {
|
||||
case string:
|
||||
// 解析 JSON 字符串
|
||||
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||
return false
|
||||
}
|
||||
case map[string]any:
|
||||
argsMap = v
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// 移除 workdir 参数(OpenCode 不支持)
|
||||
if _, exists := argsMap["workdir"]; exists {
|
||||
delete(argsMap, "workdir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||
}
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||
}
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||
// 这里可以添加参数名称的映射逻辑
|
||||
if _, exists := argsMap["file_path"]; !exists {
|
||||
if path, exists := argsMap["path"]; exists {
|
||||
argsMap["file_path"] = path
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果修正了参数,需要重新序列化
|
||||
if corrected {
|
||||
if _, wasString := arguments.(string); wasString {
|
||||
// 原本是字符串,序列化回字符串
|
||||
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||
functionCall["arguments"] = string(newArgsJSON)
|
||||
}
|
||||
} else {
|
||||
// 原本是 map,直接赋值
|
||||
functionCall["arguments"] = argsMap
|
||||
}
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// recordCorrection 记录一次工具名称修正
|
||||
func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected++
|
||||
key := fmt.Sprintf("%s->%s", from, to)
|
||||
c.stats.CorrectionsByTool[key]++
|
||||
|
||||
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||
from, to, c.stats.TotalCorrected)
|
||||
}
|
||||
|
||||
// GetStats 获取工具修正统计信息
|
||||
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// 返回副本以避免并发问题
|
||||
statsCopy := ToolCorrectionStats{
|
||||
TotalCorrected: c.stats.TotalCorrected,
|
||||
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
|
||||
}
|
||||
for k, v := range c.stats.CorrectionsByTool {
|
||||
statsCopy.CorrectionsByTool[k] = v
|
||||
}
|
||||
|
||||
return statsCopy
|
||||
}
|
||||
|
||||
// ResetStats 重置统计信息
|
||||
func (c *CodexToolCorrector) ResetStats() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected = 0
|
||||
c.stats.CorrectionsByTool = make(map[string]int)
|
||||
}
|
||||
|
||||
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
|
||||
func CorrectToolName(name string) (string, bool) {
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
return correctName, true
|
||||
}
|
||||
return name, false
|
||||
}
|
||||
|
||||
// GetToolNameMapping 获取工具名称映射表
|
||||
func GetToolNameMapping() map[string]string {
|
||||
// 返回副本以避免外部修改
|
||||
mapping := make(map[string]string, len(codexToolNameMapping))
|
||||
for k, v := range codexToolNameMapping {
|
||||
mapping[k] = v
|
||||
}
|
||||
return mapping
|
||||
}
|
||||
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectCorrected bool
|
||||
checkFunc func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "newline only",
|
||||
input: "\n",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
input: "not a json",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct apply_patch in tool_calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in result")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in function_call",
|
||||
input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
functionCall, ok := payload["function_call"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function_call format")
|
||||
}
|
||||
if functionCall["name"] != "todowrite" {
|
||||
t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct search_files in delta.tool_calls",
|
||||
input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
delta, ok := payload["delta"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid delta format")
|
||||
}
|
||||
toolCalls, ok := delta["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in delta")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "grep" {
|
||||
t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct list_files in choices.message.tool_calls",
|
||||
input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
t.Fatal("No choices found in result")
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid choice format")
|
||||
}
|
||||
message, ok := choice["message"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid message format")
|
||||
}
|
||||
toolCalls, ok := message["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in message")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "glob" {
|
||||
t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no correction needed",
|
||||
input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct multiple tool calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) < 2 {
|
||||
t.Fatal("Expected at least 2 tool_calls")
|
||||
}
|
||||
|
||||
toolCall1, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid first tool_call format")
|
||||
}
|
||||
func1, ok := toolCall1["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid first function format")
|
||||
}
|
||||
if func1["name"] != "edit" {
|
||||
t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
|
||||
}
|
||||
|
||||
toolCall2, ok := toolCalls[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid second tool_call format")
|
||||
}
|
||||
func2, ok := toolCall2["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid second function format")
|
||||
}
|
||||
if func2["name"] != "read" {
|
||||
t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "camelCase format - applyPatch",
|
||||
input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in result")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
|
||||
if corrected != tt.expectCorrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
|
||||
}
|
||||
|
||||
if !corrected && result != tt.input {
|
||||
t.Errorf("Expected unchanged result when not corrected")
|
||||
}
|
||||
|
||||
if tt.checkFunc != nil {
|
||||
tt.checkFunc(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectToolName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
corrected bool
|
||||
}{
|
||||
{"apply_patch", "edit", true},
|
||||
{"applyPatch", "edit", true},
|
||||
{"update_plan", "todowrite", true},
|
||||
{"updatePlan", "todowrite", true},
|
||||
{"read_plan", "todoread", true},
|
||||
{"readPlan", "todoread", true},
|
||||
{"search_files", "grep", true},
|
||||
{"searchFiles", "grep", true},
|
||||
{"list_files", "glob", true},
|
||||
{"listFiles", "glob", true},
|
||||
{"read_file", "read", true},
|
||||
{"readFile", "read", true},
|
||||
{"write_file", "write", true},
|
||||
{"writeFile", "write", true},
|
||||
{"execute_bash", "bash", true},
|
||||
{"executeBash", "bash", true},
|
||||
{"exec_bash", "bash", true},
|
||||
{"execBash", "bash", true},
|
||||
{"unknown_tool", "unknown_tool", false},
|
||||
{"read", "read", false},
|
||||
{"edit", "edit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result, corrected := CorrectToolName(tt.input)
|
||||
|
||||
if corrected != tt.corrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetToolNameMapping(t *testing.T) {
|
||||
mapping := GetToolNameMapping()
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"search_files": "grep",
|
||||
"list_files": "glob",
|
||||
}
|
||||
|
||||
for from, to := range expectedMappings {
|
||||
if mapping[from] != to {
|
||||
t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
|
||||
}
|
||||
}
|
||||
|
||||
mapping["test_tool"] = "test_value"
|
||||
newMapping := GetToolNameMapping()
|
||||
if _, exists := newMapping["test_tool"]; exists {
|
||||
t.Error("Modifications to returned mapping should not affect original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectorStats(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
stats := corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
|
||||
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
|
||||
corrector.ResetStats()
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
input := `{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-5.1-codex",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"file\":\"test.go\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(input)
|
||||
|
||||
if !corrected {
|
||||
t.Error("Expected data to be corrected")
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
t.Fatal("No choices found in result")
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid choice format")
|
||||
}
|
||||
delta, ok := choice["delta"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid delta format")
|
||||
}
|
||||
toolCalls, ok := delta["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in delta")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
function, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
|
||||
if function["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCorrectToolParameters 测试工具参数修正
|
||||
func TestCorrectToolParameters(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||
}{
|
||||
{
|
||||
name: "remove workdir from bash tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"command": true,
|
||||
"workdir": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rename path to file_path in edit tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"file_path": true,
|
||||
"path": false,
|
||||
"old_string": true,
|
||||
"new_string": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
if !changed {
|
||||
t.Error("expected data to be corrected")
|
||||
}
|
||||
|
||||
// 解析修正后的数据
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(corrected), &result); err != nil {
|
||||
t.Fatalf("failed to parse corrected data: %v", err)
|
||||
}
|
||||
|
||||
// 检查工具调用
|
||||
toolCalls, ok := result["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("no tool_calls found in corrected data")
|
||||
}
|
||||
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("invalid tool_call structure")
|
||||
}
|
||||
|
||||
function, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("no function found in tool_call")
|
||||
}
|
||||
|
||||
argumentsStr, ok := function["arguments"].(string)
|
||||
if !ok {
|
||||
t.Fatal("arguments is not a string")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
|
||||
t.Fatalf("failed to parse arguments: %v", err)
|
||||
}
|
||||
|
||||
// 验证期望的参数
|
||||
for param, shouldExist := range tt.expected {
|
||||
_, exists := args[param]
|
||||
if shouldExist && !exists {
|
||||
t.Errorf("expected parameter %q to exist, but it doesn't", param)
|
||||
}
|
||||
if !shouldExist && exists {
|
||||
t.Errorf("expected parameter %q to not exist, but it does", param)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
backend/internal/service/ops_account_availability.go
Normal file
194
backend/internal/service/ops_account_availability.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetAccountAvailabilityStats returns current account availability stats.
|
||||
//
|
||||
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
|
||||
func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFilter string, groupIDFilter *int64) (
|
||||
map[string]*PlatformAvailability,
|
||||
map[int64]*GroupAvailability,
|
||||
map[int64]*AccountAvailability,
|
||||
*time.Time,
|
||||
error,
|
||||
) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp != nil && grp.ID == *groupIDFilter {
|
||||
filtered = append(filtered, acc)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
accounts = filtered
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
collectedAt := now
|
||||
|
||||
platform := make(map[string]*PlatformAvailability)
|
||||
group := make(map[int64]*GroupAvailability)
|
||||
account := make(map[int64]*AccountAvailability)
|
||||
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
isTempUnsched := false
|
||||
if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
|
||||
isTempUnsched = true
|
||||
}
|
||||
|
||||
isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
|
||||
isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
|
||||
hasError := acc.Status == StatusError
|
||||
|
||||
// Normalize exclusive status flags so the UI doesn't show conflicting badges.
|
||||
if hasError {
|
||||
isRateLimited = false
|
||||
isOverloaded = false
|
||||
}
|
||||
|
||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
platform[acc.Platform] = &PlatformAvailability{
|
||||
Platform: acc.Platform,
|
||||
}
|
||||
}
|
||||
p := platform[acc.Platform]
|
||||
p.TotalAccounts++
|
||||
if isAvailable {
|
||||
p.AvailableCount++
|
||||
}
|
||||
if isRateLimited {
|
||||
p.RateLimitCount++
|
||||
}
|
||||
if hasError {
|
||||
p.ErrorCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupAvailability{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
g.TotalAccounts++
|
||||
if isAvailable {
|
||||
g.AvailableCount++
|
||||
}
|
||||
if isRateLimited {
|
||||
g.RateLimitCount++
|
||||
}
|
||||
if hasError {
|
||||
g.ErrorCount++
|
||||
}
|
||||
}
|
||||
|
||||
displayGroupID := int64(0)
|
||||
displayGroupName := ""
|
||||
if len(acc.Groups) > 0 && acc.Groups[0] != nil {
|
||||
displayGroupID = acc.Groups[0].ID
|
||||
displayGroupName = acc.Groups[0].Name
|
||||
}
|
||||
|
||||
item := &AccountAvailability{
|
||||
AccountID: acc.ID,
|
||||
AccountName: acc.Name,
|
||||
Platform: acc.Platform,
|
||||
GroupID: displayGroupID,
|
||||
GroupName: displayGroupName,
|
||||
Status: acc.Status,
|
||||
|
||||
IsAvailable: isAvailable,
|
||||
IsRateLimited: isRateLimited,
|
||||
IsOverloaded: isOverloaded,
|
||||
HasError: hasError,
|
||||
|
||||
ErrorMessage: acc.ErrorMessage,
|
||||
}
|
||||
|
||||
if isRateLimited && acc.RateLimitResetAt != nil {
|
||||
item.RateLimitResetAt = acc.RateLimitResetAt
|
||||
remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
|
||||
if remainingSec > 0 {
|
||||
item.RateLimitRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if isOverloaded && acc.OverloadUntil != nil {
|
||||
item.OverloadUntil = acc.OverloadUntil
|
||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||
if remainingSec > 0 {
|
||||
item.OverloadRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if isTempUnsched && acc.TempUnschedulableUntil != nil {
|
||||
item.TempUnschedulableUntil = acc.TempUnschedulableUntil
|
||||
}
|
||||
|
||||
account[acc.ID] = item
|
||||
}
|
||||
|
||||
return platform, group, account, &collectedAt, nil
|
||||
}
|
||||
|
||||
type OpsAccountAvailability struct {
|
||||
Group *GroupAvailability
|
||||
Accounts map[int64]*AccountAvailability
|
||||
CollectedAt *time.Time
|
||||
}
|
||||
|
||||
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("ops service is nil")
|
||||
}
|
||||
|
||||
if s.getAccountAvailability != nil {
|
||||
return s.getAccountAvailability(ctx, platformFilter, groupIDFilter)
|
||||
}
|
||||
|
||||
_, groupStats, accountStats, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var group *GroupAvailability
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
group = groupStats[*groupIDFilter]
|
||||
}
|
||||
|
||||
if accountStats == nil {
|
||||
accountStats = map[int64]*AccountAvailability{}
|
||||
}
|
||||
|
||||
return &OpsAccountAvailability{
|
||||
Group: group,
|
||||
Accounts: accountStats,
|
||||
CollectedAt: collectedAt,
|
||||
}, nil
|
||||
}
|
||||
46
backend/internal/service/ops_advisory_lock.go
Normal file
46
backend/internal/service/ops_advisory_lock.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func hashAdvisoryLockID(key string) int64 {
|
||||
h := fnv.New64a()
|
||||
_, _ = h.Write([]byte(key))
|
||||
return int64(h.Sum64())
|
||||
}
|
||||
|
||||
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
|
||||
if db == nil {
|
||||
return nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
acquired := false
|
||||
if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, false
|
||||
}
|
||||
if !acquired {
|
||||
_ = conn.Close()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
release := func() {
|
||||
unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||
_ = conn.Close()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
448
backend/internal/service/ops_aggregation_service.go
Normal file
448
backend/internal/service/ops_aggregation_service.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAggHourlyJobName = "ops_preaggregation_hourly"
|
||||
opsAggDailyJobName = "ops_preaggregation_daily"
|
||||
|
||||
opsAggHourlyInterval = 10 * time.Minute
|
||||
opsAggDailyInterval = 1 * time.Hour
|
||||
|
||||
// Keep in sync with ops retention target (vNext default 30d).
|
||||
opsAggBackfillWindow = 30 * 24 * time.Hour
|
||||
|
||||
// Recompute overlap to absorb late-arriving rows near boundaries.
|
||||
opsAggHourlyOverlap = 2 * time.Hour
|
||||
opsAggDailyOverlap = 48 * time.Hour
|
||||
|
||||
opsAggHourlyChunk = 24 * time.Hour
|
||||
opsAggDailyChunk = 7 * 24 * time.Hour
|
||||
|
||||
// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
|
||||
// that may still receive late inserts.
|
||||
opsAggSafeDelay = 5 * time.Minute
|
||||
|
||||
opsAggMaxQueryTimeout = 3 * time.Second
|
||||
opsAggHourlyTimeout = 5 * time.Minute
|
||||
opsAggDailyTimeout = 2 * time.Minute
|
||||
|
||||
opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
|
||||
opsAggDailyLeaderLockKey = "ops:aggregation:daily:leader"
|
||||
|
||||
opsAggHourlyLeaderLockTTL = 15 * time.Minute
|
||||
opsAggDailyLeaderLockTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
|
||||
// for stable long-window dashboard queries.
|
||||
//
|
||||
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
|
||||
type OpsAggregationService struct {
|
||||
opsRepo OpsRepository
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
instanceID string
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
hourlyMu sync.Mutex
|
||||
dailyMu sync.Mutex
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
}
|
||||
|
||||
func NewOpsAggregationService(
|
||||
opsRepo OpsRepository,
|
||||
settingRepo SettingRepository,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsAggregationService {
|
||||
return &OpsAggregationService{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
if s.stopCh == nil {
|
||||
s.stopCh = make(chan struct{})
|
||||
}
|
||||
go s.hourlyLoop()
|
||||
go s.dailyLoop()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stopCh != nil {
|
||||
close(s.stopCh)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) hourlyLoop() {
|
||||
// First run immediately.
|
||||
s.aggregateHourly()
|
||||
|
||||
ticker := time.NewTicker(opsAggHourlyInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.aggregateHourly()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) dailyLoop() {
|
||||
// First run immediately.
|
||||
s.aggregateDaily()
|
||||
|
||||
ticker := time.NewTicker(opsAggDailyInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.aggregateDaily()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) aggregateHourly() {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil {
|
||||
if !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Ops.Aggregation.Enabled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !s.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
s.hourlyMu.Lock()
|
||||
defer s.hourlyMu.Unlock()
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
// Aggregate stable full hours only.
|
||||
end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
|
||||
start := end.Add(-opsAggBackfillWindow)
|
||||
|
||||
// Resume from the latest bucket with overlap.
|
||||
{
|
||||
ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
|
||||
latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggHourlyOverlap)
|
||||
if candidate.After(start) {
|
||||
start = candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start = utcFloorToHour(start)
|
||||
if !start.Before(end) {
|
||||
return
|
||||
}
|
||||
|
||||
var aggErr error
|
||||
for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
|
||||
chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
|
||||
if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finishedAt := time.Now().UTC()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
|
||||
if aggErr != nil {
|
||||
msg := truncateString(aggErr.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggHourlyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggHourlyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) aggregateDaily() {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil {
|
||||
if !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Ops.Aggregation.Enabled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !s.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
s.dailyMu.Lock()
|
||||
defer s.dailyMu.Unlock()
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
end := utcFloorToDay(time.Now().UTC())
|
||||
start := end.Add(-opsAggBackfillWindow)
|
||||
|
||||
{
|
||||
ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
|
||||
latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggDailyOverlap)
|
||||
if candidate.After(start) {
|
||||
start = candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start = utcFloorToDay(start)
|
||||
if !start.Before(end) {
|
||||
return
|
||||
}
|
||||
|
||||
var aggErr error
|
||||
for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
|
||||
chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
|
||||
if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finishedAt := time.Now().UTC()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
|
||||
if aggErr != nil {
|
||||
msg := truncateString(aggErr.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggDailyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggDailyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return false
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var opsAggReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
|
||||
// the DB when Redis is flaky by falling back to a DB advisory lock.
|
||||
if s.redisClient != nil {
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err == nil {
|
||||
if !ok {
|
||||
s.maybeLogSkip(logPrefix)
|
||||
return nil, false
|
||||
}
|
||||
release := func() {
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
// Redis error: fall through to DB advisory lock.
|
||||
}
|
||||
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
|
||||
if !ok {
|
||||
s.maybeLogSkip(logPrefix)
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) maybeLogSkip(prefix string) {
|
||||
s.skipLogMu.Lock()
|
||||
defer s.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < time.Minute {
|
||||
return
|
||||
}
|
||||
s.skipLogAt = now
|
||||
if prefix == "" {
|
||||
prefix = "[OpsAggregation]"
|
||||
}
|
||||
log.Printf("%s leader lock held by another instance; skipping", prefix)
|
||||
}
|
||||
|
||||
func utcFloorToHour(t time.Time) time.Time {
|
||||
return t.UTC().Truncate(time.Hour)
|
||||
}
|
||||
|
||||
func utcFloorToDay(t time.Time) time.Time {
|
||||
u := t.UTC()
|
||||
y, m, d := u.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func minTime(a, b time.Time) time.Time {
|
||||
if a.Before(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
944
backend/internal/service/ops_alert_evaluator_service.go
Normal file
944
backend/internal/service/ops_alert_evaluator_service.go
Normal file
@@ -0,0 +1,944 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAlertEvaluatorJobName = "ops_alert_evaluator"
|
||||
|
||||
opsAlertEvaluatorTimeout = 45 * time.Second
|
||||
opsAlertEvaluatorLeaderLockKey = "ops:alert:evaluator:leader"
|
||||
opsAlertEvaluatorLeaderLockTTL = 90 * time.Second
|
||||
opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
type OpsAlertEvaluatorService struct {
|
||||
opsService *OpsService
|
||||
opsRepo OpsRepository
|
||||
emailService *EmailService
|
||||
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
instanceID string
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
mu sync.Mutex
|
||||
ruleStates map[int64]*opsAlertRuleState
|
||||
|
||||
emailLimiter *slidingWindowLimiter
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
|
||||
warnNoRedisOnce sync.Once
|
||||
}
|
||||
|
||||
type opsAlertRuleState struct {
|
||||
LastEvaluatedAt time.Time
|
||||
ConsecutiveBreaches int
|
||||
}
|
||||
|
||||
func NewOpsAlertEvaluatorService(
|
||||
opsService *OpsService,
|
||||
opsRepo OpsRepository,
|
||||
emailService *EmailService,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsAlertEvaluatorService {
|
||||
return &OpsAlertEvaluatorService{
|
||||
opsService: opsService,
|
||||
opsRepo: opsRepo,
|
||||
emailService: emailService,
|
||||
redisClient: redisClient,
|
||||
cfg: cfg,
|
||||
instanceID: uuid.NewString(),
|
||||
ruleStates: map[int64]*opsAlertRuleState{},
|
||||
emailLimiter: newSlidingWindowLimiter(0, time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
if s.stopCh == nil {
|
||||
s.stopCh = make(chan struct{})
|
||||
}
|
||||
go s.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stopCh != nil {
|
||||
close(s.stopCh)
|
||||
}
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) run() {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Start immediately to produce early feedback in ops dashboard.
|
||||
timer := time.NewTimer(0)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
interval := s.getInterval()
|
||||
s.evaluateOnce(interval)
|
||||
timer.Reset(interval)
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) getInterval() time.Duration {
|
||||
// Default.
|
||||
interval := 60 * time.Second
|
||||
|
||||
if s == nil || s.opsService == nil {
|
||||
return interval
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg, err := s.opsService.GetOpsAlertRuntimeSettings(ctx)
|
||||
if err != nil || cfg == nil {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds <= 0 {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds < 1 {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds > int((24 * time.Hour).Seconds()) {
|
||||
return interval
|
||||
}
|
||||
return time.Duration(cfg.EvaluationIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
|
||||
defer cancel()
|
||||
|
||||
if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
runtimeCfg := defaultOpsAlertRuntimeSettings()
|
||||
if s.opsService != nil {
|
||||
if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
|
||||
runtimeCfg = loaded
|
||||
}
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
rules, err := s.opsRepo.ListAlertRules(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rulesTotal := len(rules)
|
||||
rulesEnabled := 0
|
||||
rulesEvaluated := 0
|
||||
eventsCreated := 0
|
||||
eventsResolved := 0
|
||||
emailsSent := 0
|
||||
|
||||
now := time.Now().UTC()
|
||||
safeEnd := now.Truncate(time.Minute)
|
||||
if safeEnd.IsZero() {
|
||||
safeEnd = now
|
||||
}
|
||||
|
||||
systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)
|
||||
|
||||
// Cleanup stale state for removed rules.
|
||||
s.pruneRuleStates(rules)
|
||||
|
||||
for _, rule := range rules {
|
||||
if rule == nil || !rule.Enabled || rule.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
rulesEnabled++
|
||||
|
||||
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
|
||||
|
||||
windowMinutes := rule.WindowMinutes
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
|
||||
windowEnd := safeEnd
|
||||
|
||||
metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
|
||||
if !ok {
|
||||
s.resetRuleState(rule.ID, now)
|
||||
continue
|
||||
}
|
||||
rulesEvaluated++
|
||||
|
||||
breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
|
||||
required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
|
||||
consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)
|
||||
|
||||
activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if breachedNow && consecutive >= required {
|
||||
if activeEvent != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Scoped silencing: if a matching silence exists, skip creating a firing event.
|
||||
if s.opsService != nil {
|
||||
platform := strings.TrimSpace(scopePlatform)
|
||||
region := scopeRegion
|
||||
if platform != "" {
|
||||
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
if latestEvent != nil && rule.CooldownMinutes > 0 {
|
||||
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
|
||||
if now.Sub(latestEvent.FiredAt) < cooldown {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
firedEvent := &OpsAlertEvent{
|
||||
RuleID: rule.ID,
|
||||
Severity: strings.TrimSpace(rule.Severity),
|
||||
Status: OpsAlertStatusFiring,
|
||||
Title: fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
|
||||
Description: buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
|
||||
MetricValue: float64Ptr(metricValue),
|
||||
ThresholdValue: float64Ptr(rule.Threshold),
|
||||
Dimensions: buildOpsAlertDimensions(scopePlatform, scopeGroupID),
|
||||
FiredAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsCreated++
|
||||
if created != nil && created.ID > 0 {
|
||||
if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
|
||||
emailsSent++
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Not breached: resolve active event if present.
|
||||
if activeEvent != nil {
|
||||
resolvedAt := now
|
||||
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
|
||||
} else {
|
||||
eventsResolved++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
live := map[int64]struct{}{}
|
||||
for _, r := range rules {
|
||||
if r != nil && r.ID > 0 {
|
||||
live[r.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for id := range s.ruleStates {
|
||||
if _, ok := live[id]; !ok {
|
||||
delete(s.ruleStates, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) resetRuleState(ruleID int64, now time.Time) {
|
||||
if ruleID <= 0 {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
state, ok := s.ruleStates[ruleID]
|
||||
if !ok {
|
||||
state = &opsAlertRuleState{}
|
||||
s.ruleStates[ruleID] = state
|
||||
}
|
||||
state.LastEvaluatedAt = now
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) updateRuleBreaches(ruleID int64, now time.Time, interval time.Duration, breached bool) int {
|
||||
if ruleID <= 0 {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
state, ok := s.ruleStates[ruleID]
|
||||
if !ok {
|
||||
state = &opsAlertRuleState{}
|
||||
s.ruleStates[ruleID] = state
|
||||
}
|
||||
|
||||
if !state.LastEvaluatedAt.IsZero() && interval > 0 {
|
||||
if now.Sub(state.LastEvaluatedAt) > interval*2 {
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
}
|
||||
|
||||
state.LastEvaluatedAt = now
|
||||
if breached {
|
||||
state.ConsecutiveBreaches++
|
||||
} else {
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
return state.ConsecutiveBreaches
|
||||
}
|
||||
|
||||
func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int {
|
||||
if sustainedMinutes <= 0 {
|
||||
return 1
|
||||
}
|
||||
if interval <= 0 {
|
||||
return sustainedMinutes
|
||||
}
|
||||
required := int(math.Ceil(float64(sustainedMinutes*60) / interval.Seconds()))
|
||||
if required < 1 {
|
||||
return 1
|
||||
}
|
||||
return required
|
||||
}
|
||||
|
||||
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
|
||||
if filters == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
if v, ok := filters["platform"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
platform = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
if v, ok := filters["group_id"]; ok {
|
||||
switch t := v.(type) {
|
||||
case float64:
|
||||
if t > 0 {
|
||||
id := int64(t)
|
||||
groupID = &id
|
||||
}
|
||||
case int64:
|
||||
if t > 0 {
|
||||
id := t
|
||||
groupID = &id
|
||||
}
|
||||
case int:
|
||||
if t > 0 {
|
||||
id := int64(t)
|
||||
groupID = &id
|
||||
}
|
||||
case string:
|
||||
n, err := strconv.ParseInt(strings.TrimSpace(t), 10, 64)
|
||||
if err == nil && n > 0 {
|
||||
groupID = &n
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := filters["region"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
vv := strings.TrimSpace(s)
|
||||
if vv != "" {
|
||||
region = &vv
|
||||
}
|
||||
}
|
||||
}
|
||||
return platform, groupID, region
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
||||
ctx context.Context,
|
||||
rule *OpsAlertRule,
|
||||
systemMetrics *OpsSystemMetricsSnapshot,
|
||||
start time.Time,
|
||||
end time.Time,
|
||||
platform string,
|
||||
groupID *int64,
|
||||
) (float64, bool) {
|
||||
if rule == nil {
|
||||
return 0, false
|
||||
}
|
||||
switch strings.TrimSpace(rule.MetricType) {
|
||||
case "cpu_usage_percent":
|
||||
if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
|
||||
return *systemMetrics.CPUUsagePercent, true
|
||||
}
|
||||
return 0, false
|
||||
case "memory_usage_percent":
|
||||
if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
|
||||
return *systemMetrics.MemoryUsagePercent, true
|
||||
}
|
||||
return 0, false
|
||||
case "concurrency_queue_depth":
|
||||
if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
|
||||
return float64(*systemMetrics.ConcurrencyQueueDepth), true
|
||||
}
|
||||
return 0, false
|
||||
case "group_available_accounts":
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
if availability.Group == nil {
|
||||
return 0, true
|
||||
}
|
||||
return float64(availability.Group.AvailableCount), true
|
||||
case "group_available_ratio":
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return computeGroupAvailableRatio(availability.Group), true
|
||||
case "account_rate_limited_count":
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})), true
|
||||
case "account_error_count":
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||
})), true
|
||||
}
|
||||
|
||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: platform,
|
||||
GroupID: groupID,
|
||||
QueryMode: OpsQueryModeRaw,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
if overview == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(rule.MetricType) {
|
||||
case "success_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.SLA * 100, true
|
||||
case "error_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.ErrorRate * 100, true
|
||||
case "upstream_error_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.UpstreamErrorRate * 100, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func compareMetric(value float64, operator string, threshold float64) bool {
|
||||
switch strings.TrimSpace(operator) {
|
||||
case ">":
|
||||
return value > threshold
|
||||
case ">=":
|
||||
return value >= threshold
|
||||
case "<":
|
||||
return value < threshold
|
||||
case "<=":
|
||||
return value <= threshold
|
||||
case "==":
|
||||
return value == threshold
|
||||
case "!=":
|
||||
return value != threshold
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
|
||||
dims := map[string]any{}
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
dims["platform"] = strings.TrimSpace(platform)
|
||||
}
|
||||
if groupID != nil && *groupID > 0 {
|
||||
dims["group_id"] = *groupID
|
||||
}
|
||||
if len(dims) == 0 {
|
||||
return nil
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes int, platform string, groupID *int64) string {
|
||||
if rule == nil {
|
||||
return ""
|
||||
}
|
||||
scope := "overall"
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
scope = fmt.Sprintf("platform=%s", strings.TrimSpace(platform))
|
||||
}
|
||||
if groupID != nil && *groupID > 0 {
|
||||
scope = fmt.Sprintf("%s group_id=%d", scope, *groupID)
|
||||
}
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
return fmt.Sprintf("%s %s %.2f (current %.2f) over last %dm (%s)",
|
||||
strings.TrimSpace(rule.MetricType),
|
||||
strings.TrimSpace(rule.Operator),
|
||||
rule.Threshold,
|
||||
value,
|
||||
windowMinutes,
|
||||
strings.TrimSpace(scope),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
|
||||
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
|
||||
return false
|
||||
}
|
||||
if event.EmailSent {
|
||||
return false
|
||||
}
|
||||
if !rule.NotifyEmail {
|
||||
return false
|
||||
}
|
||||
|
||||
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
||||
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(emailCfg.Alert.Recipients) == 0 {
|
||||
return false
|
||||
}
|
||||
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
|
||||
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Apply/update rate limiter.
|
||||
s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)
|
||||
|
||||
subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
|
||||
body := buildOpsAlertEmailBody(rule, event)
|
||||
|
||||
anySent := false
|
||||
for _, to := range emailCfg.Alert.Recipients {
|
||||
addr := strings.TrimSpace(to)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
if !s.emailLimiter.Allow(time.Now().UTC()) {
|
||||
continue
|
||||
}
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
}
|
||||
anySent = true
|
||||
}
|
||||
|
||||
if anySent {
|
||||
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
|
||||
}
|
||||
return anySent
|
||||
}
|
||||
|
||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||
if rule == nil || event == nil {
|
||||
return ""
|
||||
}
|
||||
metric := strings.TrimSpace(rule.MetricType)
|
||||
value := "-"
|
||||
threshold := fmt.Sprintf("%.2f", rule.Threshold)
|
||||
if event.MetricValue != nil {
|
||||
value = fmt.Sprintf("%.2f", *event.MetricValue)
|
||||
}
|
||||
if event.ThresholdValue != nil {
|
||||
threshold = fmt.Sprintf("%.2f", *event.ThresholdValue)
|
||||
}
|
||||
return fmt.Sprintf(`
|
||||
<h2>Ops Alert</h2>
|
||||
<p><b>Rule</b>: %s</p>
|
||||
<p><b>Severity</b>: %s</p>
|
||||
<p><b>Status</b>: %s</p>
|
||||
<p><b>Metric</b>: %s %s %s</p>
|
||||
<p><b>Fired at</b>: %s</p>
|
||||
<p><b>Description</b>: %s</p>
|
||||
`,
|
||||
htmlEscape(rule.Name),
|
||||
htmlEscape(rule.Severity),
|
||||
htmlEscape(event.Status),
|
||||
htmlEscape(metric),
|
||||
htmlEscape(rule.Operator),
|
||||
htmlEscape(fmt.Sprintf("%s (threshold %s)", value, threshold)),
|
||||
event.FiredAt.Format(time.RFC3339),
|
||||
htmlEscape(event.Description),
|
||||
)
|
||||
}
|
||||
|
||||
func shouldSendOpsAlertEmailByMinSeverity(minSeverity string, ruleSeverity string) bool {
|
||||
minSeverity = strings.ToLower(strings.TrimSpace(minSeverity))
|
||||
if minSeverity == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
eventLevel := opsEmailSeverityForOps(ruleSeverity)
|
||||
minLevel := strings.ToLower(minSeverity)
|
||||
|
||||
rank := func(level string) int {
|
||||
switch level {
|
||||
case "critical":
|
||||
return 3
|
||||
case "warning":
|
||||
return 2
|
||||
case "info":
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return rank(eventLevel) >= rank(minLevel)
|
||||
}
|
||||
|
||||
func opsEmailSeverityForOps(severity string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(severity)) {
|
||||
case "P0":
|
||||
return "critical"
|
||||
case "P1":
|
||||
return "warning"
|
||||
default:
|
||||
return "info"
|
||||
}
|
||||
}
|
||||
|
||||
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
|
||||
if !silencing.Enabled {
|
||||
return false
|
||||
}
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
if strings.TrimSpace(silencing.GlobalUntilRFC3339) != "" {
|
||||
if t, err := time.Parse(time.RFC3339, strings.TrimSpace(silencing.GlobalUntilRFC3339)); err == nil {
|
||||
if now.Before(t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, entry := range silencing.Entries {
|
||||
untilRaw := strings.TrimSpace(entry.UntilRFC3339)
|
||||
if untilRaw == "" {
|
||||
continue
|
||||
}
|
||||
until, err := time.Parse(time.RFC3339, untilRaw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if now.After(until) {
|
||||
continue
|
||||
}
|
||||
if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
|
||||
continue
|
||||
}
|
||||
if len(entry.Severities) > 0 {
|
||||
match := false
|
||||
for _, s := range entry.Severities {
|
||||
if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(event.Severity)) || strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(rule.Severity)) {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
|
||||
if !lock.Enabled {
|
||||
return nil, true
|
||||
}
|
||||
if s.redisClient == nil {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
|
||||
})
|
||||
return nil, true
|
||||
}
|
||||
key := strings.TrimSpace(lock.Key)
|
||||
if key == "" {
|
||||
key = opsAlertEvaluatorLeaderLockKey
|
||||
}
|
||||
ttl := time.Duration(lock.TTLSeconds) * time.Second
|
||||
if ttl <= 0 {
|
||||
ttl = opsAlertEvaluatorLeaderLockTTL
|
||||
}
|
||||
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err != nil {
|
||||
// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
|
||||
// Single-node deployments can disable the distributed lock via runtime settings.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
if !ok {
|
||||
s.maybeLogSkip(key)
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
||||
s.skipLogMu.Lock()
|
||||
defer s.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < opsAlertEvaluatorSkipLogInterval {
|
||||
return
|
||||
}
|
||||
s.skipLogAt = now
|
||||
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg := strings.TrimSpace(result)
|
||||
if msg == "" {
|
||||
msg = "ok"
|
||||
}
|
||||
msg = truncateString(msg, 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAlertEvaluatorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||
if s == nil || s.opsRepo == nil || err == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAlertEvaluatorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &now,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"&", "&",
|
||||
"<", "<",
|
||||
">", ">",
|
||||
`"`, """,
|
||||
"'", "'",
|
||||
)
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
|
||||
type slidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
limit int
|
||||
window time.Duration
|
||||
sent []time.Time
|
||||
}
|
||||
|
||||
func newSlidingWindowLimiter(limit int, window time.Duration) *slidingWindowLimiter {
|
||||
if window <= 0 {
|
||||
window = time.Hour
|
||||
}
|
||||
return &slidingWindowLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
sent: []time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *slidingWindowLimiter) SetLimit(limit int) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.limit = limit
|
||||
}
|
||||
|
||||
func (l *slidingWindowLimiter) Allow(now time.Time) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.limit <= 0 {
|
||||
return true
|
||||
}
|
||||
cutoff := now.Add(-l.window)
|
||||
keep := l.sent[:0]
|
||||
for _, t := range l.sent {
|
||||
if t.After(cutoff) {
|
||||
keep = append(keep, t)
|
||||
}
|
||||
}
|
||||
l.sent = keep
|
||||
if len(l.sent) >= l.limit {
|
||||
return false
|
||||
}
|
||||
l.sent = append(l.sent, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// computeGroupAvailableRatio returns the available percentage for a group.
|
||||
// Formula: (AvailableCount / TotalAccounts) * 100.
|
||||
// Returns 0 when TotalAccounts is 0.
|
||||
func computeGroupAvailableRatio(group *GroupAvailability) float64 {
|
||||
if group == nil || group.TotalAccounts <= 0 {
|
||||
return 0
|
||||
}
|
||||
return (float64(group.AvailableCount) / float64(group.TotalAccounts)) * 100
|
||||
}
|
||||
|
||||
// countAccountsByCondition counts accounts that satisfy the given condition.
|
||||
func countAccountsByCondition(accounts map[int64]*AccountAvailability, condition func(*AccountAvailability) bool) int64 {
|
||||
if len(accounts) == 0 || condition == nil {
|
||||
return 0
|
||||
}
|
||||
var count int64
|
||||
for _, account := range accounts {
|
||||
if account != nil && condition(account) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubOpsRepo struct {
|
||||
OpsRepository
|
||||
overview *OpsDashboardOverview
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubOpsRepo) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.overview != nil {
|
||||
return s.overview, nil
|
||||
}
|
||||
return &OpsDashboardOverview{}, nil
|
||||
}
|
||||
|
||||
func TestComputeGroupAvailableRatio(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("正常情况: 10个账号, 8个可用 = 80%", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 8,
|
||||
})
|
||||
require.InDelta(t, 80.0, got, 0.0001)
|
||||
})
|
||||
|
||||
t.Run("边界情况: TotalAccounts = 0 应返回 0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 0,
|
||||
AvailableCount: 8,
|
||||
})
|
||||
require.Equal(t, 0.0, got)
|
||||
})
|
||||
|
||||
t.Run("边界情况: AvailableCount = 0 应返回 0%", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 0,
|
||||
})
|
||||
require.Equal(t, 0.0, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountAccountsByCondition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("测试限流账号统计: acc.IsRateLimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
accounts := map[int64]*AccountAvailability{
|
||||
1: {IsRateLimited: true},
|
||||
2: {IsRateLimited: false},
|
||||
3: {IsRateLimited: true},
|
||||
}
|
||||
|
||||
got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})
|
||||
require.Equal(t, int64(2), got)
|
||||
})
|
||||
|
||||
t.Run("测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
until := time.Now().UTC().Add(5 * time.Minute)
|
||||
accounts := map[int64]*AccountAvailability{
|
||||
1: {HasError: true},
|
||||
2: {HasError: true, TempUnschedulableUntil: &until},
|
||||
3: {HasError: false},
|
||||
}
|
||||
|
||||
got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||
})
|
||||
require.Equal(t, int64(1), got)
|
||||
})
|
||||
|
||||
t.Run("边界情况: 空 map 应返回 0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := countAccountsByCondition(map[int64]*AccountAvailability{}, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})
|
||||
require.Equal(t, int64(0), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestComputeRuleMetricNewIndicators(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
groupID := int64(101)
|
||||
platform := "openai"
|
||||
|
||||
availability := &OpsAccountAvailability{
|
||||
Group: &GroupAvailability{
|
||||
GroupID: groupID,
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 8,
|
||||
},
|
||||
Accounts: map[int64]*AccountAvailability{
|
||||
1: {IsRateLimited: true},
|
||||
2: {IsRateLimited: true},
|
||||
3: {HasError: true},
|
||||
4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
|
||||
5: {HasError: false, IsRateLimited: false},
|
||||
},
|
||||
}
|
||||
|
||||
opsService := &OpsService{
|
||||
getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
|
||||
return availability, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpsAlertEvaluatorService{
|
||||
opsService: opsService,
|
||||
opsRepo: &stubOpsRepo{overview: &OpsDashboardOverview{}},
|
||||
}
|
||||
|
||||
start := time.Now().UTC().Add(-5 * time.Minute)
|
||||
end := time.Now().UTC()
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metricType string
|
||||
groupID *int64
|
||||
wantValue float64
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "group_available_accounts",
|
||||
metricType: "group_available_accounts",
|
||||
groupID: &groupID,
|
||||
wantValue: 8,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "group_available_ratio",
|
||||
metricType: "group_available_ratio",
|
||||
groupID: &groupID,
|
||||
wantValue: 80.0,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "account_rate_limited_count",
|
||||
metricType: "account_rate_limited_count",
|
||||
groupID: nil,
|
||||
wantValue: 2,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "account_error_count",
|
||||
metricType: "account_error_count",
|
||||
groupID: nil,
|
||||
wantValue: 1,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "group_available_accounts without group_id returns false",
|
||||
metricType: "group_available_accounts",
|
||||
groupID: nil,
|
||||
wantValue: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "group_available_ratio without group_id returns false",
|
||||
metricType: "group_available_ratio",
|
||||
groupID: nil,
|
||||
wantValue: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule := &OpsAlertRule{
|
||||
MetricType: tt.metricType,
|
||||
}
|
||||
gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
|
||||
require.Equal(t, tt.wantOK, gotOK)
|
||||
if !tt.wantOK {
|
||||
return
|
||||
}
|
||||
require.InDelta(t, tt.wantValue, gotValue, 0.0001)
|
||||
})
|
||||
}
|
||||
}
|
||||
95
backend/internal/service/ops_alert_models.go
Normal file
95
backend/internal/service/ops_alert_models.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// Ops alert rule/event models.
|
||||
//
|
||||
// NOTE: These are admin-facing DTOs and intentionally keep JSON naming aligned
|
||||
// with the existing ops dashboard frontend (backup style).
|
||||
|
||||
const (
|
||||
OpsAlertStatusFiring = "firing"
|
||||
OpsAlertStatusResolved = "resolved"
|
||||
OpsAlertStatusManualResolved = "manual_resolved"
|
||||
)
|
||||
|
||||
type OpsAlertRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
|
||||
Enabled bool `json:"enabled"`
|
||||
Severity string `json:"severity"`
|
||||
|
||||
MetricType string `json:"metric_type"`
|
||||
Operator string `json:"operator"`
|
||||
Threshold float64 `json:"threshold"`
|
||||
|
||||
WindowMinutes int `json:"window_minutes"`
|
||||
SustainedMinutes int `json:"sustained_minutes"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
|
||||
NotifyEmail bool `json:"notify_email"`
|
||||
|
||||
Filters map[string]any `json:"filters,omitempty"`
|
||||
|
||||
LastTriggeredAt *time.Time `json:"last_triggered_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEvent struct {
|
||||
ID int64 `json:"id"`
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
|
||||
MetricValue *float64 `json:"metric_value,omitempty"`
|
||||
ThresholdValue *float64 `json:"threshold_value,omitempty"`
|
||||
|
||||
Dimensions map[string]any `json:"dimensions,omitempty"`
|
||||
|
||||
FiredAt time.Time `json:"fired_at"`
|
||||
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
|
||||
|
||||
EmailSent bool `json:"email_sent"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertSilence struct {
|
||||
ID int64 `json:"id"`
|
||||
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Region *string `json:"region,omitempty"`
|
||||
|
||||
Until time.Time `json:"until"`
|
||||
Reason string `json:"reason"`
|
||||
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEventFilter struct {
|
||||
Limit int
|
||||
|
||||
// Cursor pagination (descending by fired_at, then id).
|
||||
BeforeFiredAt *time.Time
|
||||
BeforeID *int64
|
||||
|
||||
// Optional filters.
|
||||
Status string
|
||||
Severity string
|
||||
EmailSent *bool
|
||||
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
// Dimensions filters (best-effort).
|
||||
Platform string
|
||||
GroupID *int64
|
||||
}
|
||||
232
backend/internal/service/ops_alerts.go
Normal file
232
backend/internal/service/ops_alerts.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return []*OpsAlertRule{}, nil
|
||||
}
|
||||
return s.opsRepo.ListAlertRules(ctx)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if rule == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertRule(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if rule == nil || rule.ID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
|
||||
}
|
||||
|
||||
updated, err := s.opsRepo.UpdateAlertRule(ctx, rule)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if id <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if err := s.opsRepo.DeleteAlertRule(ctx, id); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return []*OpsAlertEvent{}, nil
|
||||
}
|
||||
return s.opsRepo.ListAlertEvents(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ev == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
|
||||
}
|
||||
if input.RuleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(input.Platform) == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
|
||||
}
|
||||
if input.Until.IsZero() {
|
||||
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(platform) == "" {
|
||||
return false, nil
|
||||
}
|
||||
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
return s.opsRepo.GetLatestAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if event == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_EVENT", "invalid event")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertEvent(ctx, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
status = strings.TrimSpace(status)
|
||||
if status == "" {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
return s.opsRepo.UpdateAlertEventEmailSent(ctx, eventID, emailSent)
|
||||
}
|
||||
367
backend/internal/service/ops_cleanup_service.go
Normal file
367
backend/internal/service/ops_cleanup_service.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
opsCleanupJobName = "ops_cleanup"
|
||||
|
||||
opsCleanupLeaderLockKeyDefault = "ops:cleanup:leader"
|
||||
opsCleanupLeaderLockTTLDefault = 30 * time.Minute
|
||||
)
|
||||
|
||||
var opsCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
|
||||
var opsCleanupReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// OpsCleanupService periodically deletes old ops data to prevent unbounded DB growth.
|
||||
//
|
||||
// - Scheduling: 5-field cron spec (minute hour dom month dow).
|
||||
// - Multi-instance: best-effort Redis leader lock so only one node runs cleanup.
|
||||
// - Safety: deletes in batches to avoid long transactions.
|
||||
type OpsCleanupService struct {
|
||||
opsRepo OpsRepository
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
|
||||
instanceID string
|
||||
|
||||
cron *cron.Cron
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
warnNoRedisOnce sync.Once
|
||||
}
|
||||
|
||||
func NewOpsCleanupService(
|
||||
opsRepo OpsRepository,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsCleanupService {
|
||||
return &OpsCleanupService{
|
||||
opsRepo: opsRepo,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
cfg: cfg,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled {
|
||||
log.Printf("[OpsCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.opsRepo == nil || s.db == nil {
|
||||
log.Printf("[OpsCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
schedule := "0 2 * * *"
|
||||
if s.cfg != nil && strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule) != "" {
|
||||
schedule = strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule)
|
||||
}
|
||||
|
||||
loc := time.Local
|
||||
if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" {
|
||||
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
|
||||
loc = parsed
|
||||
}
|
||||
}
|
||||
|
||||
c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc))
|
||||
_, err := c.AddFunc(schedule, func() { s.runScheduled() })
|
||||
if err != nil {
|
||||
log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.cron != nil {
|
||||
ctx := s.cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Printf("[OpsCleanup] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) runScheduled() {
|
||||
if s == nil || s.db == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
counts, err := s.runCleanupOnce(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsCleanup] cleanup failed: %v", err)
|
||||
return
|
||||
}
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
|
||||
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
|
||||
}
|
||||
|
||||
type opsCleanupDeletedCounts struct {
|
||||
errorLogs int64
|
||||
retryAttempts int64
|
||||
alertEvents int64
|
||||
systemMetrics int64
|
||||
hourlyPreagg int64
|
||||
dailyPreagg int64
|
||||
}
|
||||
|
||||
func (c opsCleanupDeletedCounts) String() string {
|
||||
return fmt.Sprintf(
|
||||
"error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
|
||||
c.errorLogs,
|
||||
c.retryAttempts,
|
||||
c.alertEvents,
|
||||
c.systemMetrics,
|
||||
c.hourlyPreagg,
|
||||
c.dailyPreagg,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
|
||||
out := opsCleanupDeletedCounts{}
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
batchSize := 5000
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Error-like tables: error logs / retry attempts / alert events.
|
||||
if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.errorLogs = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.retryAttempts = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.alertEvents = n
|
||||
}
|
||||
|
||||
// Minute-level metrics snapshots.
|
||||
if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.systemMetrics = n
|
||||
}
|
||||
|
||||
// Pre-aggregation tables (hourly/daily).
|
||||
if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.hourlyPreagg = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.dailyPreagg = n
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func deleteOldRowsByID(
|
||||
ctx context.Context,
|
||||
db *sql.DB,
|
||||
table string,
|
||||
timeColumn string,
|
||||
cutoff time.Time,
|
||||
batchSize int,
|
||||
castCutoffToDate bool,
|
||||
) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if batchSize <= 0 {
|
||||
batchSize = 5000
|
||||
}
|
||||
|
||||
where := fmt.Sprintf("%s < $1", timeColumn)
|
||||
if castCutoffToDate {
|
||||
where = fmt.Sprintf("%s < $1::date", timeColumn)
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
WITH batch AS (
|
||||
SELECT id FROM %s
|
||||
WHERE %s
|
||||
ORDER BY id
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM %s
|
||||
WHERE id IN (SELECT id FROM batch)
|
||||
`, table, where, table)
|
||||
|
||||
var total int64
|
||||
for {
|
||||
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
|
||||
if err != nil {
|
||||
// If ops tables aren't present yet (partial deployments), treat as no-op.
|
||||
if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
|
||||
return total, nil
|
||||
}
|
||||
return total, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += affected
|
||||
if affected == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
}
|
||||
// In simple run mode, assume single instance.
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
key := opsCleanupLeaderLockKeyDefault
|
||||
ttl := opsCleanupLeaderLockTTLDefault
|
||||
|
||||
// Prefer Redis leader lock when available, but avoid stampeding the DB when Redis is flaky by
|
||||
// falling back to a DB advisory lock.
|
||||
if s.redisClient != nil {
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err == nil {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsCleanupReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
// Redis error: fall back to DB advisory lock.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err)
|
||||
})
|
||||
} else {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] redis not configured; using DB advisory lock")
|
||||
})
|
||||
}
|
||||
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
result := truncateString(counts.String(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsCleanupJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||
if s == nil || s.opsRepo == nil || err == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsCleanupJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &now,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
257
backend/internal/service/ops_concurrency.go
Normal file
257
backend/internal/service/ops_concurrency.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAccountsPageSize = 100
|
||||
opsConcurrencyBatchChunkSize = 200
|
||||
)
|
||||
|
||||
func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter string) ([]Account, error) {
|
||||
if s == nil || s.accountRepo == nil {
|
||||
return []Account{}, nil
|
||||
}
|
||||
|
||||
out := make([]Account, 0, 128)
|
||||
page := 1
|
||||
for {
|
||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: opsAccountsPageSize,
|
||||
}, platformFilter, "", "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
out = append(out, accounts...)
|
||||
if pageInfo != nil && int64(len(out)) >= pageInfo.Total {
|
||||
break
|
||||
}
|
||||
if len(accounts) < opsAccountsPageSize {
|
||||
break
|
||||
}
|
||||
|
||||
page++
|
||||
if page > 10_000 {
|
||||
log.Printf("[Ops] listAllAccountsForOps: aborting after too many pages (platform=%q)", platformFilter)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts []Account) map[int64]*AccountLoadInfo {
|
||||
if s == nil || s.concurrencyService == nil {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
|
||||
// De-duplicate IDs (and keep the max concurrency to avoid under-reporting).
|
||||
unique := make(map[int64]int, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
|
||||
unique[acc.ID] = acc.Concurrency
|
||||
}
|
||||
}
|
||||
|
||||
batch := make([]AccountWithConcurrency, 0, len(unique))
|
||||
for id, maxConc := range unique {
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: id,
|
||||
MaxConcurrency: maxConc,
|
||||
})
|
||||
}
|
||||
|
||||
out := make(map[int64]*AccountLoadInfo, len(batch))
|
||||
for i := 0; i < len(batch); i += opsConcurrencyBatchChunkSize {
|
||||
end := i + opsConcurrencyBatchChunkSize
|
||||
if end > len(batch) {
|
||||
end = len(batch)
|
||||
}
|
||||
part, err := s.concurrencyService.GetAccountsLoadBatch(ctx, batch[i:end])
|
||||
if err != nil {
|
||||
// Best-effort: return zeros rather than failing the ops UI.
|
||||
log.Printf("[Ops] GetAccountsLoadBatch failed: %v", err)
|
||||
continue
|
||||
}
|
||||
for k, v := range part {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
|
||||
//
|
||||
// Optional filters:
|
||||
// - platformFilter: only include accounts in that platform (best-effort reduces DB load)
|
||||
// - groupIDFilter: only include accounts that belong to that group
|
||||
func (s *OpsService) GetConcurrencyStats(
|
||||
ctx context.Context,
|
||||
platformFilter string,
|
||||
groupIDFilter *int64,
|
||||
) (map[string]*PlatformConcurrencyInfo, map[int64]*GroupConcurrencyInfo, map[int64]*AccountConcurrencyInfo, *time.Time, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
collectedAt := time.Now()
|
||||
loadMap := s.getAccountsLoadMapBestEffort(ctx, accounts)
|
||||
|
||||
platform := make(map[string]*PlatformConcurrencyInfo)
|
||||
group := make(map[int64]*GroupConcurrencyInfo)
|
||||
account := make(map[int64]*AccountConcurrencyInfo)
|
||||
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var matchedGroup *Group
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if grp.ID == *groupIDFilter {
|
||||
matchedGroup = grp
|
||||
break
|
||||
}
|
||||
}
|
||||
// Group filter provided: skip accounts not in that group.
|
||||
if matchedGroup == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
load := loadMap[acc.ID]
|
||||
currentInUse := int64(0)
|
||||
waiting := int64(0)
|
||||
if load != nil {
|
||||
currentInUse = int64(load.CurrentConcurrency)
|
||||
waiting = int64(load.WaitingCount)
|
||||
}
|
||||
|
||||
// Account-level view picks one display group (the first group).
|
||||
displayGroupID := int64(0)
|
||||
displayGroupName := ""
|
||||
if matchedGroup != nil {
|
||||
displayGroupID = matchedGroup.ID
|
||||
displayGroupName = matchedGroup.Name
|
||||
} else if len(acc.Groups) > 0 && acc.Groups[0] != nil {
|
||||
displayGroupID = acc.Groups[0].ID
|
||||
displayGroupName = acc.Groups[0].Name
|
||||
}
|
||||
|
||||
if _, ok := account[acc.ID]; !ok {
|
||||
info := &AccountConcurrencyInfo{
|
||||
AccountID: acc.ID,
|
||||
AccountName: acc.Name,
|
||||
Platform: acc.Platform,
|
||||
GroupID: displayGroupID,
|
||||
GroupName: displayGroupName,
|
||||
CurrentInUse: currentInUse,
|
||||
MaxCapacity: int64(acc.Concurrency),
|
||||
WaitingInQueue: waiting,
|
||||
}
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
account[acc.ID] = info
|
||||
}
|
||||
|
||||
// Platform aggregation.
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
platform[acc.Platform] = &PlatformConcurrencyInfo{
|
||||
Platform: acc.Platform,
|
||||
}
|
||||
}
|
||||
p := platform[acc.Platform]
|
||||
p.MaxCapacity += int64(acc.Concurrency)
|
||||
p.CurrentInUse += currentInUse
|
||||
p.WaitingInQueue += waiting
|
||||
}
|
||||
|
||||
// Group aggregation (one account may contribute to multiple groups).
|
||||
if matchedGroup != nil {
|
||||
grp := matchedGroup
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupConcurrencyInfo{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
if g.GroupName == "" && grp.Name != "" {
|
||||
g.GroupName = grp.Name
|
||||
}
|
||||
if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
|
||||
// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
|
||||
g.Platform = ""
|
||||
}
|
||||
g.MaxCapacity += int64(acc.Concurrency)
|
||||
g.CurrentInUse += currentInUse
|
||||
g.WaitingInQueue += waiting
|
||||
} else {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupConcurrencyInfo{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
if g.GroupName == "" && grp.Name != "" {
|
||||
g.GroupName = grp.Name
|
||||
}
|
||||
if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
|
||||
// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
|
||||
g.Platform = ""
|
||||
}
|
||||
g.MaxCapacity += int64(acc.Concurrency)
|
||||
g.CurrentInUse += currentInUse
|
||||
g.WaitingInQueue += waiting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, info := range platform {
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
}
|
||||
for _, info := range group {
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
}
|
||||
|
||||
return platform, group, account, &collectedAt, nil
|
||||
}
|
||||
90
backend/internal/service/ops_dashboard.go
Normal file
90
backend/internal/service/ops_dashboard.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
|
||||
// Resolve query mode (requested via query param, or DB default).
|
||||
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||
|
||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
|
||||
return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Best-effort system health + jobs; dashboard metrics should still render if these are missing.
|
||||
if metrics, err := s.opsRepo.GetLatestSystemMetrics(ctx, 1); err == nil {
|
||||
// Attach config-derived limits so the UI can show "current / max" for connection pools.
|
||||
// These are best-effort and should never block the dashboard rendering.
|
||||
if s != nil && s.cfg != nil {
|
||||
if s.cfg.Database.MaxOpenConns > 0 {
|
||||
metrics.DBMaxOpenConns = intPtr(s.cfg.Database.MaxOpenConns)
|
||||
}
|
||||
if s.cfg.Redis.PoolSize > 0 {
|
||||
metrics.RedisPoolSize = intPtr(s.cfg.Redis.PoolSize)
|
||||
}
|
||||
}
|
||||
overview.SystemMetrics = metrics
|
||||
} else if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
log.Printf("[Ops] GetLatestSystemMetrics failed: %v", err)
|
||||
}
|
||||
|
||||
if heartbeats, err := s.opsRepo.ListJobHeartbeats(ctx); err == nil {
|
||||
overview.JobHeartbeats = heartbeats
|
||||
} else {
|
||||
log.Printf("[Ops] ListJobHeartbeats failed: %v", err)
|
||||
}
|
||||
|
||||
overview.HealthScore = computeDashboardHealthScore(time.Now().UTC(), overview)
|
||||
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) resolveOpsQueryMode(ctx context.Context, requested OpsQueryMode) OpsQueryMode {
|
||||
if requested.IsValid() {
|
||||
// Allow "auto" to be disabled via config until preagg is proven stable in production.
|
||||
// Forced `preagg` via query param still works.
|
||||
if requested == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
|
||||
return OpsQueryModeRaw
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
mode := OpsQueryModeAuto
|
||||
if s != nil && s.settingRepo != nil {
|
||||
if raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsQueryModeDefault); err == nil {
|
||||
mode = ParseOpsQueryMode(raw)
|
||||
}
|
||||
}
|
||||
|
||||
if mode == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
|
||||
return OpsQueryModeRaw
|
||||
}
|
||||
return mode
|
||||
}
|
||||
87
backend/internal/service/ops_dashboard_models.go
Normal file
87
backend/internal/service/ops_dashboard_models.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsDashboardFilter struct {
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
|
||||
Platform string
|
||||
GroupID *int64
|
||||
|
||||
// QueryMode controls whether dashboard queries should use raw logs or pre-aggregated tables.
|
||||
// Expected values: auto/raw/preagg (see OpsQueryMode).
|
||||
QueryMode OpsQueryMode
|
||||
}
|
||||
|
||||
type OpsRateSummary struct {
|
||||
Current float64 `json:"current"`
|
||||
Peak float64 `json:"peak"`
|
||||
Avg float64 `json:"avg"`
|
||||
}
|
||||
|
||||
type OpsPercentiles struct {
|
||||
P50 *int `json:"p50_ms"`
|
||||
P90 *int `json:"p90_ms"`
|
||||
P95 *int `json:"p95_ms"`
|
||||
P99 *int `json:"p99_ms"`
|
||||
Avg *int `json:"avg_ms"`
|
||||
Max *int `json:"max_ms"`
|
||||
}
|
||||
|
||||
type OpsDashboardOverview struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
|
||||
// HealthScore is a backend-computed overall health score (0-100).
|
||||
// It is derived from the monitored metrics in this overview, plus best-effort system metrics/job heartbeats.
|
||||
HealthScore int `json:"health_score"`
|
||||
|
||||
// Latest system-level snapshot (window=1m, global).
|
||||
SystemMetrics *OpsSystemMetricsSnapshot `json:"system_metrics"`
|
||||
|
||||
// Background jobs health (heartbeats).
|
||||
JobHeartbeats []*OpsJobHeartbeat `json:"job_heartbeats"`
|
||||
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
ErrorCountTotal int64 `json:"error_count_total"`
|
||||
BusinessLimitedCount int64 `json:"business_limited_count"`
|
||||
|
||||
ErrorCountSLA int64 `json:"error_count_sla"`
|
||||
RequestCountTotal int64 `json:"request_count_total"`
|
||||
RequestCountSLA int64 `json:"request_count_sla"`
|
||||
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
|
||||
SLA float64 `json:"sla"`
|
||||
ErrorRate float64 `json:"error_rate"`
|
||||
UpstreamErrorRate float64 `json:"upstream_error_rate"`
|
||||
UpstreamErrorCountExcl429529 int64 `json:"upstream_error_count_excl_429_529"`
|
||||
Upstream429Count int64 `json:"upstream_429_count"`
|
||||
Upstream529Count int64 `json:"upstream_529_count"`
|
||||
|
||||
QPS OpsRateSummary `json:"qps"`
|
||||
TPS OpsRateSummary `json:"tps"`
|
||||
|
||||
Duration OpsPercentiles `json:"duration"`
|
||||
TTFT OpsPercentiles `json:"ttft"`
|
||||
}
|
||||
|
||||
type OpsLatencyHistogramBucket struct {
|
||||
Range string `json:"range"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
// OpsLatencyHistogramResponse is a coarse latency distribution histogram (success requests only).
|
||||
// It is used by the Ops dashboard to quickly identify tail latency regressions.
|
||||
type OpsLatencyHistogramResponse struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
Buckets []*OpsLatencyHistogramBucket `json:"buckets"`
|
||||
}
|
||||
45
backend/internal/service/ops_errors.go
Normal file
45
backend/internal/service/ops_errors.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetErrorDistribution(ctx, filter)
|
||||
}
|
||||
143
backend/internal/service/ops_health_score.go
Normal file
143
backend/internal/service/ops_health_score.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// computeDashboardHealthScore computes a 0-100 health score from the metrics returned by the dashboard overview.
|
||||
//
|
||||
// Design goals:
|
||||
// - Backend-owned scoring (UI only displays).
|
||||
// - Layered scoring: Business Health (70%) + Infrastructure Health (30%)
|
||||
// - Avoids double-counting (e.g., DB failure affects both infra and business metrics)
|
||||
// - Conservative + stable: penalize clear degradations; avoid overreacting to missing/idle data.
|
||||
func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview) int {
|
||||
if overview == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Idle/no-data: avoid showing a "bad" score when there is no traffic.
|
||||
// UI can still render a gray/idle state based on QPS + error rate.
|
||||
if overview.RequestCountSLA <= 0 && overview.RequestCountTotal <= 0 && overview.ErrorCountTotal <= 0 {
|
||||
return 100
|
||||
}
|
||||
|
||||
businessHealth := computeBusinessHealth(overview)
|
||||
infraHealth := computeInfraHealth(now, overview)
|
||||
|
||||
// Weighted combination: 70% business + 30% infrastructure
|
||||
score := businessHealth*0.7 + infraHealth*0.3
|
||||
return int(math.Round(clampFloat64(score, 0, 100)))
|
||||
}
|
||||
|
||||
// computeBusinessHealth calculates business health score (0-100)
|
||||
// Components: Error Rate (50%) + TTFT (50%)
|
||||
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
|
||||
// Error rate score: 1% → 100, 10% → 0 (linear)
|
||||
// Combines request errors and upstream errors
|
||||
errorScore := 100.0
|
||||
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
|
||||
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
|
||||
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
|
||||
if combinedErrorPct > 1.0 {
|
||||
if combinedErrorPct <= 10.0 {
|
||||
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
|
||||
} else {
|
||||
errorScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// TTFT score: 1s → 100, 3s → 0 (linear)
|
||||
// Time to first token is critical for user experience
|
||||
ttftScore := 100.0
|
||||
if overview.TTFT.P99 != nil {
|
||||
p99 := float64(*overview.TTFT.P99)
|
||||
if p99 > 1000 {
|
||||
if p99 <= 3000 {
|
||||
ttftScore = (3000 - p99) / 2000 * 100
|
||||
} else {
|
||||
ttftScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Weighted combination: 50% error rate + 50% TTFT
|
||||
return errorScore*0.5 + ttftScore*0.5
|
||||
}
|
||||
|
||||
// computeInfraHealth calculates infrastructure health score (0-100)
|
||||
// Components: Storage (40%) + Compute Resources (30%) + Background Jobs (30%)
|
||||
func computeInfraHealth(now time.Time, overview *OpsDashboardOverview) float64 {
|
||||
// Storage score: DB critical, Redis less critical
|
||||
storageScore := 100.0
|
||||
if overview.SystemMetrics != nil {
|
||||
if overview.SystemMetrics.DBOK != nil && !*overview.SystemMetrics.DBOK {
|
||||
storageScore = 0 // DB failure is critical
|
||||
} else if overview.SystemMetrics.RedisOK != nil && !*overview.SystemMetrics.RedisOK {
|
||||
storageScore = 50 // Redis failure is degraded but not critical
|
||||
}
|
||||
}
|
||||
|
||||
// Compute resources score: CPU + Memory
|
||||
computeScore := 100.0
|
||||
if overview.SystemMetrics != nil {
|
||||
cpuScore := 100.0
|
||||
if overview.SystemMetrics.CPUUsagePercent != nil {
|
||||
cpuPct := clampFloat64(*overview.SystemMetrics.CPUUsagePercent, 0, 100)
|
||||
if cpuPct > 80 {
|
||||
if cpuPct <= 100 {
|
||||
cpuScore = (100 - cpuPct) / 20 * 100
|
||||
} else {
|
||||
cpuScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memScore := 100.0
|
||||
if overview.SystemMetrics.MemoryUsagePercent != nil {
|
||||
memPct := clampFloat64(*overview.SystemMetrics.MemoryUsagePercent, 0, 100)
|
||||
if memPct > 85 {
|
||||
if memPct <= 100 {
|
||||
memScore = (100 - memPct) / 15 * 100
|
||||
} else {
|
||||
memScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computeScore = (cpuScore + memScore) / 2
|
||||
}
|
||||
|
||||
// Background jobs score
|
||||
jobScore := 100.0
|
||||
failedJobs := 0
|
||||
totalJobs := 0
|
||||
for _, hb := range overview.JobHeartbeats {
|
||||
if hb == nil {
|
||||
continue
|
||||
}
|
||||
totalJobs++
|
||||
if hb.LastErrorAt != nil && (hb.LastSuccessAt == nil || hb.LastErrorAt.After(*hb.LastSuccessAt)) {
|
||||
failedJobs++
|
||||
} else if hb.LastSuccessAt != nil && now.Sub(*hb.LastSuccessAt) > 15*time.Minute {
|
||||
failedJobs++
|
||||
}
|
||||
}
|
||||
if totalJobs > 0 && failedJobs > 0 {
|
||||
jobScore = (1 - float64(failedJobs)/float64(totalJobs)) * 100
|
||||
}
|
||||
|
||||
// Weighted combination
|
||||
return storageScore*0.4 + computeScore*0.3 + jobScore*0.3
|
||||
}
|
||||
|
||||
func clampFloat64(v float64, min float64, max float64) float64 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
442
backend/internal/service/ops_health_score_test.go
Normal file
442
backend/internal/service/ops_health_score_test.go
Normal file
@@ -0,0 +1,442 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestComputeDashboardHealthScore_IdleReturns100(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), &OpsDashboardOverview{})
|
||||
require.Equal(t, 100, score)
|
||||
}
|
||||
|
||||
func TestComputeDashboardHealthScore_DegradesOnBadSignals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ov := &OpsDashboardOverview{
|
||||
RequestCountTotal: 100,
|
||||
RequestCountSLA: 100,
|
||||
SuccessCount: 90,
|
||||
ErrorCountTotal: 10,
|
||||
ErrorCountSLA: 10,
|
||||
|
||||
SLA: 0.90,
|
||||
ErrorRate: 0.10,
|
||||
UpstreamErrorRate: 0.08,
|
||||
|
||||
Duration: OpsPercentiles{P99: intPtr(20_000)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(2_000)},
|
||||
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(98.0),
|
||||
MemoryUsagePercent: float64Ptr(97.0),
|
||||
DBConnWaiting: intPtr(3),
|
||||
ConcurrencyQueueDepth: intPtr(10),
|
||||
},
|
||||
JobHeartbeats: []*OpsJobHeartbeat{
|
||||
{
|
||||
JobName: "job-a",
|
||||
LastErrorAt: timePtr(time.Now().UTC().Add(-1 * time.Minute)),
|
||||
LastError: stringPtr("boom"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), ov)
|
||||
require.Less(t, score, 80)
|
||||
require.GreaterOrEqual(t, score, 0)
|
||||
}
|
||||
|
||||
func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin int
|
||||
wantMax int
|
||||
}{
|
||||
{
|
||||
name: "nil overview returns 0",
|
||||
overview: nil,
|
||||
wantMin: 0,
|
||||
wantMax: 0,
|
||||
},
|
||||
{
|
||||
name: "perfect health",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 1.0,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(100)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "good health - SLA 99.8%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.998,
|
||||
ErrorRate: 0.003,
|
||||
UpstreamErrorRate: 0.001,
|
||||
Duration: OpsPercentiles{P99: intPtr(800)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(200)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(50),
|
||||
MemoryUsagePercent: float64Ptr(60),
|
||||
},
|
||||
},
|
||||
wantMin: 95,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "medium health - SLA 96%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.96,
|
||||
ErrorRate: 0.02,
|
||||
UpstreamErrorRate: 0.01,
|
||||
Duration: OpsPercentiles{P99: intPtr(3000)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(600)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(70),
|
||||
MemoryUsagePercent: float64Ptr(75),
|
||||
},
|
||||
},
|
||||
wantMin: 96,
|
||||
wantMax: 97,
|
||||
},
|
||||
{
|
||||
name: "DB failure",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
{
|
||||
name: "Redis failure",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "high CPU usage",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(95),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "combined failures - business degraded + infra healthy",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.90,
|
||||
ErrorRate: 0.05,
|
||||
UpstreamErrorRate: 0.02,
|
||||
Duration: OpsPercentiles{P99: intPtr(10000)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(20),
|
||||
MemoryUsagePercent: float64Ptr(30),
|
||||
},
|
||||
},
|
||||
wantMin: 84,
|
||||
wantMax: 85,
|
||||
},
|
||||
{
|
||||
name: "combined failures - business healthy + infra degraded",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.998,
|
||||
ErrorRate: 0.001,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(600)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(95),
|
||||
MemoryUsagePercent: float64Ptr(95),
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %d", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %d", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeBusinessHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin float64
|
||||
wantMax float64
|
||||
}{
|
||||
{
|
||||
name: "perfect metrics",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 1.0,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "SLA boundary 99.5%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "SLA boundary 95%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.95,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "error rate boundary 1%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0.01,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "error rate 5%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.95,
|
||||
ErrorRate: 0.05,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 77,
|
||||
wantMax: 78,
|
||||
},
|
||||
{
|
||||
name: "TTFT boundary 2s",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
TTFT: OpsPercentiles{P99: intPtr(2000)},
|
||||
},
|
||||
wantMin: 75,
|
||||
wantMax: 75,
|
||||
},
|
||||
{
|
||||
name: "upstream error dominates",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0.001,
|
||||
UpstreamErrorRate: 0.03,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 88,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeBusinessHealth(tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100.0, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeInfraHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin float64
|
||||
wantMax float64
|
||||
}{
|
||||
{
|
||||
name: "all infrastructure healthy",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "DB down",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 50,
|
||||
wantMax: 70,
|
||||
},
|
||||
{
|
||||
name: "Redis down",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 80,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "CPU at 90%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(90),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "failed background job",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
JobHeartbeats: []*OpsJobHeartbeat{
|
||||
{
|
||||
JobName: "test-job",
|
||||
LastErrorAt: &now,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeInfraHealth(now, tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100.0, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func timePtr(v time.Time) *time.Time { return &v }
|
||||
|
||||
func stringPtr(v string) *string { return &v }
|
||||
26
backend/internal/service/ops_histograms.go
Normal file
26
backend/internal/service/ops_histograms.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetLatencyHistogram(ctx, filter)
|
||||
}
|
||||
920
backend/internal/service/ops_metrics_collector.go
Normal file
920
backend/internal/service/ops_metrics_collector.go
Normal file
@@ -0,0 +1,920 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
)
|
||||
|
||||
const (
|
||||
opsMetricsCollectorJobName = "ops_metrics_collector"
|
||||
opsMetricsCollectorMinInterval = 60 * time.Second
|
||||
opsMetricsCollectorMaxInterval = 1 * time.Hour
|
||||
|
||||
opsMetricsCollectorTimeout = 10 * time.Second
|
||||
|
||||
opsMetricsCollectorLeaderLockKey = "ops:metrics:collector:leader"
|
||||
opsMetricsCollectorLeaderLockTTL = 90 * time.Second
|
||||
|
||||
opsMetricsCollectorHeartbeatTimeout = 2 * time.Second
|
||||
|
||||
bytesPerMB = 1024 * 1024
|
||||
)
|
||||
|
||||
var opsMetricsCollectorAdvisoryLockID = hashAdvisoryLockID(opsMetricsCollectorLeaderLockKey)
|
||||
|
||||
type OpsMetricsCollector struct {
|
||||
opsRepo OpsRepository
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
|
||||
accountRepo AccountRepository
|
||||
concurrencyService *ConcurrencyService
|
||||
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
instanceID string
|
||||
|
||||
lastCgroupCPUUsageNanos uint64
|
||||
lastCgroupCPUSampleAt time.Time
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
}
|
||||
|
||||
func NewOpsMetricsCollector(
|
||||
opsRepo OpsRepository,
|
||||
settingRepo SettingRepository,
|
||||
accountRepo AccountRepository,
|
||||
concurrencyService *ConcurrencyService,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsMetricsCollector {
|
||||
return &OpsMetricsCollector{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
concurrencyService: concurrencyService,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Start() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.startOnce.Do(func() {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
}
|
||||
go c.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Stop() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.stopOnce.Do(func() {
|
||||
if c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) run() {
|
||||
// First run immediately so the dashboard has data soon after startup.
|
||||
c.collectOnce()
|
||||
|
||||
for {
|
||||
interval := c.getInterval()
|
||||
timer := time.NewTimer(interval)
|
||||
select {
|
||||
case <-timer.C:
|
||||
c.collectOnce()
|
||||
case <-c.stopCh:
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) getInterval() time.Duration {
|
||||
interval := opsMetricsCollectorMinInterval
|
||||
|
||||
if c.settingRepo == nil {
|
||||
return interval
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
raw, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMetricsIntervalSeconds)
|
||||
if err != nil {
|
||||
return interval
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return interval
|
||||
}
|
||||
|
||||
seconds, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return interval
|
||||
}
|
||||
if seconds < int(opsMetricsCollectorMinInterval.Seconds()) {
|
||||
seconds = int(opsMetricsCollectorMinInterval.Seconds())
|
||||
}
|
||||
if seconds > int(opsMetricsCollectorMaxInterval.Seconds()) {
|
||||
seconds = int(opsMetricsCollectorMaxInterval.Seconds())
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectOnce() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if c.cfg != nil && !c.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if c.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if c.db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectorTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !c.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := c.tryAcquireLeaderLock(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
err := c.collectAndPersist(ctx)
|
||||
finishedAt := time.Now().UTC()
|
||||
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
runAt := startedAt
|
||||
|
||||
if err != nil {
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
|
||||
defer hbCancel()
|
||||
_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsMetricsCollectorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
log.Printf("[OpsMetricsCollector] collect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
|
||||
defer hbCancel()
|
||||
_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsMetricsCollectorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) isMonitoringEnabled(ctx context.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if c.cfg != nil && !c.cfg.Ops.Enabled {
|
||||
return false
|
||||
}
|
||||
if c.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
value, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
// Fail-open: collector should not become a hard dependency.
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Align to stable minute boundaries to avoid partial buckets and to maximize cache hits.
|
||||
now := time.Now().UTC()
|
||||
windowEnd := now.Truncate(time.Minute)
|
||||
windowStart := windowEnd.Add(-1 * time.Minute)
|
||||
|
||||
sys, err := c.collectSystemStats(ctx)
|
||||
if err != nil {
|
||||
// Continue; system stats are best-effort.
|
||||
log.Printf("[OpsMetricsCollector] system stats error: %v", err)
|
||||
}
|
||||
|
||||
dbOK := c.checkDB(ctx)
|
||||
redisOK := c.checkRedis(ctx)
|
||||
active, idle := c.dbPoolStats()
|
||||
redisTotal, redisIdle, redisStatsOK := c.redisPoolStats()
|
||||
|
||||
successCount, tokenConsumed, err := c.queryUsageCounts(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query usage counts: %w", err)
|
||||
}
|
||||
|
||||
duration, ttft, err := c.queryUsageLatency(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query usage latency: %w", err)
|
||||
}
|
||||
|
||||
errorTotal, businessLimited, errorSLA, upstreamExcl, upstream429, upstream529, err := c.queryErrorCounts(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query error counts: %w", err)
|
||||
}
|
||||
|
||||
windowSeconds := windowEnd.Sub(windowStart).Seconds()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 60
|
||||
}
|
||||
requestTotal := successCount + errorTotal
|
||||
qps := float64(requestTotal) / windowSeconds
|
||||
tps := float64(tokenConsumed) / windowSeconds
|
||||
|
||||
goroutines := runtime.NumGoroutine()
|
||||
concurrencyQueueDepth := c.collectConcurrencyQueueDepth(ctx)
|
||||
|
||||
input := &OpsInsertSystemMetricsInput{
|
||||
CreatedAt: windowEnd,
|
||||
WindowMinutes: 1,
|
||||
|
||||
SuccessCount: successCount,
|
||||
ErrorCountTotal: errorTotal,
|
||||
BusinessLimitedCount: businessLimited,
|
||||
ErrorCountSLA: errorSLA,
|
||||
|
||||
UpstreamErrorCountExcl429529: upstreamExcl,
|
||||
Upstream429Count: upstream429,
|
||||
Upstream529Count: upstream529,
|
||||
|
||||
TokenConsumed: tokenConsumed,
|
||||
QPS: float64Ptr(roundTo1DP(qps)),
|
||||
TPS: float64Ptr(roundTo1DP(tps)),
|
||||
|
||||
DurationP50Ms: duration.p50,
|
||||
DurationP90Ms: duration.p90,
|
||||
DurationP95Ms: duration.p95,
|
||||
DurationP99Ms: duration.p99,
|
||||
DurationAvgMs: duration.avg,
|
||||
DurationMaxMs: duration.max,
|
||||
|
||||
TTFTP50Ms: ttft.p50,
|
||||
TTFTP90Ms: ttft.p90,
|
||||
TTFTP95Ms: ttft.p95,
|
||||
TTFTP99Ms: ttft.p99,
|
||||
TTFTAvgMs: ttft.avg,
|
||||
TTFTMaxMs: ttft.max,
|
||||
|
||||
CPUUsagePercent: sys.cpuUsagePercent,
|
||||
MemoryUsedMB: sys.memoryUsedMB,
|
||||
MemoryTotalMB: sys.memoryTotalMB,
|
||||
MemoryUsagePercent: sys.memoryUsagePercent,
|
||||
|
||||
DBOK: boolPtr(dbOK),
|
||||
RedisOK: boolPtr(redisOK),
|
||||
|
||||
RedisConnTotal: func() *int {
|
||||
if !redisStatsOK {
|
||||
return nil
|
||||
}
|
||||
return intPtr(redisTotal)
|
||||
}(),
|
||||
RedisConnIdle: func() *int {
|
||||
if !redisStatsOK {
|
||||
return nil
|
||||
}
|
||||
return intPtr(redisIdle)
|
||||
}(),
|
||||
|
||||
DBConnActive: intPtr(active),
|
||||
DBConnIdle: intPtr(idle),
|
||||
GoroutineCount: intPtr(goroutines),
|
||||
ConcurrencyQueueDepth: concurrencyQueueDepth,
|
||||
}
|
||||
|
||||
return c.opsRepo.InsertSystemMetrics(ctx, input)
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Context) *int {
|
||||
if c == nil || c.accountRepo == nil || c.concurrencyService == nil {
|
||||
return nil
|
||||
}
|
||||
if parentCtx == nil {
|
||||
parentCtx = context.Background()
|
||||
}
|
||||
|
||||
// Best-effort: never let concurrency sampling break the metrics collector.
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := c.accountRepo.ListSchedulable(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
zero := 0
|
||||
return &zero
|
||||
}
|
||||
|
||||
batch := make([]AccountWithConcurrency, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
maxConc := acc.Concurrency
|
||||
if maxConc < 0 {
|
||||
maxConc = 0
|
||||
}
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: maxConc,
|
||||
})
|
||||
}
|
||||
if len(batch) == 0 {
|
||||
zero := 0
|
||||
return &zero
|
||||
}
|
||||
|
||||
loadMap, err := c.concurrencyService.GetAccountsLoadBatch(ctx, batch)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, info := range loadMap {
|
||||
if info == nil || info.WaitingCount <= 0 {
|
||||
continue
|
||||
}
|
||||
total += int64(info.WaitingCount)
|
||||
}
|
||||
if total < 0 {
|
||||
total = 0
|
||||
}
|
||||
|
||||
maxInt := int64(^uint(0) >> 1)
|
||||
if total > maxInt {
|
||||
total = maxInt
|
||||
}
|
||||
v := int(total)
|
||||
return &v
|
||||
}
|
||||
|
||||
type opsCollectedPercentiles struct {
|
||||
p50 *int
|
||||
p90 *int
|
||||
p95 *int
|
||||
p99 *int
|
||||
avg *float64
|
||||
max *int
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryUsageCounts(ctx context.Context, start, end time.Time) (successCount int64, tokenConsumed int64, err error) {
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(COUNT(*), 0) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2`
|
||||
|
||||
var tokens sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&successCount, &tokens); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
return successCount, tokenConsumed, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryUsageLatency(ctx context.Context, start, end time.Time) (duration opsCollectedPercentiles, ttft opsCollectedPercentiles, err error) {
|
||||
{
|
||||
q := `
|
||||
SELECT
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
|
||||
AVG(duration_ms) AS avg_ms,
|
||||
MAX(duration_ms) AS max_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND duration_ms IS NOT NULL`
|
||||
|
||||
var p50, p90, p95, p99 sql.NullFloat64
|
||||
var avg sql.NullFloat64
|
||||
var max sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
|
||||
return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
|
||||
}
|
||||
duration.p50 = floatToIntPtr(p50)
|
||||
duration.p90 = floatToIntPtr(p90)
|
||||
duration.p95 = floatToIntPtr(p95)
|
||||
duration.p99 = floatToIntPtr(p99)
|
||||
if avg.Valid {
|
||||
v := roundTo1DP(avg.Float64)
|
||||
duration.avg = &v
|
||||
}
|
||||
if max.Valid {
|
||||
v := int(max.Int64)
|
||||
duration.max = &v
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
q := `
|
||||
SELECT
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
|
||||
AVG(first_token_ms) AS avg_ms,
|
||||
MAX(first_token_ms) AS max_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND first_token_ms IS NOT NULL`
|
||||
|
||||
var p50, p90, p95, p99 sql.NullFloat64
|
||||
var avg sql.NullFloat64
|
||||
var max sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
|
||||
return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
|
||||
}
|
||||
ttft.p50 = floatToIntPtr(p50)
|
||||
ttft.p90 = floatToIntPtr(p90)
|
||||
ttft.p95 = floatToIntPtr(p95)
|
||||
ttft.p99 = floatToIntPtr(p99)
|
||||
if avg.Valid {
|
||||
v := roundTo1DP(avg.Float64)
|
||||
ttft.avg = &v
|
||||
}
|
||||
if max.Valid {
|
||||
v := int(max.Int64)
|
||||
ttft.max = &v
|
||||
}
|
||||
}
|
||||
|
||||
return duration, ttft, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryErrorCounts(ctx context.Context, start, end time.Time) (
|
||||
errorTotal int64,
|
||||
businessLimited int64,
|
||||
errorSLA int64,
|
||||
upstreamExcl429529 int64,
|
||||
upstream429 int64,
|
||||
upstream529 int64,
|
||||
err error,
|
||||
) {
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400), 0) AS error_total,
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited), 0) AS business_limited,
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited), 0) AS error_sla,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)), 0) AS upstream_excl,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429), 0) AS upstream_429,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529), 0) AS upstream_529
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2`
|
||||
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(
|
||||
&errorTotal,
|
||||
&businessLimited,
|
||||
&errorSLA,
|
||||
&upstreamExcl429529,
|
||||
&upstream429,
|
||||
&upstream529,
|
||||
); err != nil {
|
||||
return 0, 0, 0, 0, 0, 0, err
|
||||
}
|
||||
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
|
||||
}
|
||||
|
||||
type opsCollectedSystemStats struct {
|
||||
cpuUsagePercent *float64
|
||||
memoryUsedMB *int64
|
||||
memoryTotalMB *int64
|
||||
memoryUsagePercent *float64
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) (*opsCollectedSystemStats, error) {
|
||||
out := &opsCollectedSystemStats{}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
sampleAt := time.Now().UTC()
|
||||
|
||||
// Prefer cgroup (container) metrics when available.
|
||||
if cpuPct := c.tryCgroupCPUPercent(sampleAt); cpuPct != nil {
|
||||
out.cpuUsagePercent = cpuPct
|
||||
}
|
||||
|
||||
cgroupUsed, cgroupTotal, cgroupOK := readCgroupMemoryBytes()
|
||||
if cgroupOK {
|
||||
usedMB := int64(cgroupUsed / bytesPerMB)
|
||||
out.memoryUsedMB = &usedMB
|
||||
if cgroupTotal > 0 {
|
||||
totalMB := int64(cgroupTotal / bytesPerMB)
|
||||
out.memoryTotalMB = &totalMB
|
||||
pct := roundTo1DP(float64(cgroupUsed) / float64(cgroupTotal) * 100)
|
||||
out.memoryUsagePercent = &pct
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to host metrics if cgroup metrics are unavailable (or incomplete).
|
||||
if out.cpuUsagePercent == nil {
|
||||
if cpuPercents, err := cpu.PercentWithContext(ctx, 0, false); err == nil && len(cpuPercents) > 0 {
|
||||
v := roundTo1DP(cpuPercents[0])
|
||||
out.cpuUsagePercent = &v
|
||||
}
|
||||
}
|
||||
|
||||
// If total memory isn't available from cgroup (e.g. memory.max = "max"), fill total from host.
|
||||
if out.memoryUsedMB == nil || out.memoryTotalMB == nil || out.memoryUsagePercent == nil {
|
||||
if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil && vm != nil {
|
||||
if out.memoryUsedMB == nil {
|
||||
usedMB := int64(vm.Used / bytesPerMB)
|
||||
out.memoryUsedMB = &usedMB
|
||||
}
|
||||
if out.memoryTotalMB == nil {
|
||||
totalMB := int64(vm.Total / bytesPerMB)
|
||||
out.memoryTotalMB = &totalMB
|
||||
}
|
||||
if out.memoryUsagePercent == nil {
|
||||
if out.memoryUsedMB != nil && out.memoryTotalMB != nil && *out.memoryTotalMB > 0 {
|
||||
pct := roundTo1DP(float64(*out.memoryUsedMB) / float64(*out.memoryTotalMB) * 100)
|
||||
out.memoryUsagePercent = &pct
|
||||
} else {
|
||||
pct := roundTo1DP(vm.UsedPercent)
|
||||
out.memoryUsagePercent = &pct
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) tryCgroupCPUPercent(now time.Time) *float64 {
|
||||
usageNanos, ok := readCgroupCPUUsageNanos()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize baseline sample.
|
||||
if c.lastCgroupCPUSampleAt.IsZero() {
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
elapsed := now.Sub(c.lastCgroupCPUSampleAt)
|
||||
if elapsed <= 0 {
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
prev := c.lastCgroupCPUUsageNanos
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
|
||||
if usageNanos < prev {
|
||||
// Counter reset (container restarted).
|
||||
return nil
|
||||
}
|
||||
|
||||
deltaUsageSec := float64(usageNanos-prev) / 1e9
|
||||
elapsedSec := elapsed.Seconds()
|
||||
if elapsedSec <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cores := readCgroupCPULimitCores()
|
||||
if cores <= 0 {
|
||||
// Can't reliably normalize; skip and fall back to gopsutil.
|
||||
return nil
|
||||
}
|
||||
|
||||
pct := (deltaUsageSec / (elapsedSec * cores)) * 100
|
||||
if pct < 0 {
|
||||
pct = 0
|
||||
}
|
||||
// Clamp to avoid noise/jitter showing impossible values.
|
||||
if pct > 100 {
|
||||
pct = 100
|
||||
}
|
||||
v := roundTo1DP(pct)
|
||||
return &v
|
||||
}
|
||||
|
||||
func readCgroupMemoryBytes() (usedBytes uint64, totalBytes uint64, ok bool) {
|
||||
// cgroup v2 (most common in modern containers)
|
||||
if used, ok1 := readUintFile("/sys/fs/cgroup/memory.current"); ok1 {
|
||||
usedBytes = used
|
||||
rawMax, err := os.ReadFile("/sys/fs/cgroup/memory.max")
|
||||
if err == nil {
|
||||
s := strings.TrimSpace(string(rawMax))
|
||||
if s != "" && s != "max" {
|
||||
if v, err := strconv.ParseUint(s, 10, 64); err == nil {
|
||||
totalBytes = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return usedBytes, totalBytes, true
|
||||
}
|
||||
|
||||
// cgroup v1 fallback
|
||||
if used, ok1 := readUintFile("/sys/fs/cgroup/memory/memory.usage_in_bytes"); ok1 {
|
||||
usedBytes = used
|
||||
if limit, ok2 := readUintFile("/sys/fs/cgroup/memory/memory.limit_in_bytes"); ok2 {
|
||||
// Some environments report a very large number when unlimited.
|
||||
if limit > 0 && limit < (1<<60) {
|
||||
totalBytes = limit
|
||||
}
|
||||
}
|
||||
return usedBytes, totalBytes, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func readCgroupCPUUsageNanos() (usageNanos uint64, ok bool) {
|
||||
// cgroup v2: cpu.stat has usage_usec
|
||||
if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.stat"); err == nil {
|
||||
lines := strings.Split(string(raw), "\n")
|
||||
for _, line := range lines {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 {
|
||||
continue
|
||||
}
|
||||
if fields[0] != "usage_usec" {
|
||||
continue
|
||||
}
|
||||
v, err := strconv.ParseUint(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return v * 1000, true
|
||||
}
|
||||
}
|
||||
|
||||
// cgroup v1: cpuacct.usage is in nanoseconds
|
||||
if v, ok := readUintFile("/sys/fs/cgroup/cpuacct/cpuacct.usage"); ok {
|
||||
return v, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func readCgroupCPULimitCores() float64 {
|
||||
// cgroup v2: cpu.max => "<quota> <period>" or "max <period>"
|
||||
if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.max"); err == nil {
|
||||
fields := strings.Fields(string(raw))
|
||||
if len(fields) >= 2 && fields[0] != "max" {
|
||||
quota, err1 := strconv.ParseFloat(fields[0], 64)
|
||||
period, err2 := strconv.ParseFloat(fields[1], 64)
|
||||
if err1 == nil && err2 == nil && quota > 0 && period > 0 {
|
||||
return quota / period
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cgroup v1: cpu.cfs_quota_us / cpu.cfs_period_us
|
||||
quota, okQuota := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
|
||||
period, okPeriod := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
|
||||
if okQuota && okPeriod && quota > 0 && period > 0 {
|
||||
return float64(quota) / float64(period)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func readUintFile(path string) (uint64, bool) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
s := strings.TrimSpace(string(raw))
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func readIntFile(path string) (int64, bool) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
s := strings.TrimSpace(string(raw))
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) checkDB(ctx context.Context) bool {
|
||||
if c == nil || c.db == nil {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
var one int
|
||||
if err := c.db.QueryRowContext(ctx, "SELECT 1").Scan(&one); err != nil {
|
||||
return false
|
||||
}
|
||||
return one == 1
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) checkRedis(ctx context.Context) bool {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.redisClient.Ping(ctx).Err() == nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) redisPoolStats() (total int, idle int, ok bool) {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
stats := c.redisClient.PoolStats()
|
||||
if stats == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
return int(stats.TotalConns), int(stats.IdleConns), true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) dbPoolStats() (active int, idle int) {
|
||||
if c == nil || c.db == nil {
|
||||
return 0, 0
|
||||
}
|
||||
stats := c.db.Stats()
|
||||
return stats.InUse, stats.Idle
|
||||
}
|
||||
|
||||
var opsMetricsCollectorReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
func (c *OpsMetricsCollector) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return nil, true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
ok, err := c.redisClient.SetNX(ctx, opsMetricsCollectorLeaderLockKey, c.instanceID, opsMetricsCollectorLeaderLockTTL).Result()
|
||||
if err != nil {
|
||||
// Prefer fail-closed to avoid stampeding the database when Redis is flaky.
|
||||
// Fallback to a DB advisory lock when Redis is present but unavailable.
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, c.db, opsMetricsCollectorAdvisoryLockID)
|
||||
if !ok {
|
||||
c.maybeLogSkip()
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
if !ok {
|
||||
c.maybeLogSkip()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
release := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = opsMetricsCollectorReleaseScript.Run(ctx, c.redisClient, []string{opsMetricsCollectorLeaderLockKey}, c.instanceID).Result()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) maybeLogSkip() {
|
||||
c.skipLogMu.Lock()
|
||||
defer c.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !c.skipLogAt.IsZero() && now.Sub(c.skipLogAt) < time.Minute {
|
||||
return
|
||||
}
|
||||
c.skipLogAt = now
|
||||
log.Printf("[OpsMetricsCollector] leader lock held by another instance; skipping")
|
||||
}
|
||||
|
||||
func floatToIntPtr(v sql.NullFloat64) *int {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
n := int(math.Round(v.Float64))
|
||||
return &n
|
||||
}
|
||||
|
||||
func roundTo1DP(v float64) float64 {
|
||||
return math.Round(v*10) / 10
|
||||
}
|
||||
|
||||
func truncateString(s string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
cut := s[:max]
|
||||
for len(cut) > 0 && !utf8.ValidString(cut) {
|
||||
cut = cut[:len(cut)-1]
|
||||
}
|
||||
return cut
|
||||
}
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
|
||||
func intPtr(v int) *int {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
|
||||
func float64Ptr(v float64) *float64 {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
169
backend/internal/service/ops_models.go
Normal file
169
backend/internal/service/ops_models.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsErrorLog struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Standardized classification
|
||||
// - phase: request|auth|routing|upstream|network|internal
|
||||
// - owner: client|provider|platform
|
||||
// - source: client_request|upstream_http|gateway
|
||||
Phase string `json:"phase"`
|
||||
Type string `json:"type"`
|
||||
|
||||
Owner string `json:"error_owner"`
|
||||
Source string `json:"error_source"`
|
||||
|
||||
Severity string `json:"severity"`
|
||||
|
||||
StatusCode int `json:"status_code"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
|
||||
IsRetryable bool `json:"is_retryable"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
|
||||
Resolved bool `json:"resolved"`
|
||||
ResolvedAt *time.Time `json:"resolved_at"`
|
||||
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
|
||||
ResolvedByUserName string `json:"resolved_by_user_name"`
|
||||
ResolvedRetryID *int64 `json:"resolved_retry_id"`
|
||||
ResolvedStatusRaw string `json:"-"`
|
||||
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
|
||||
UserID *int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
|
||||
ClientIP *string `json:"client_ip"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OpsErrorLogDetail struct {
|
||||
OpsErrorLog
|
||||
|
||||
ErrorBody string `json:"error_body"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
|
||||
// Upstream context (optional)
|
||||
UpstreamStatusCode *int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamErrorMessage string `json:"upstream_error_message,omitempty"`
|
||||
UpstreamErrorDetail string `json:"upstream_error_detail,omitempty"`
|
||||
UpstreamErrors string `json:"upstream_errors,omitempty"` // JSON array (string) for display/parsing
|
||||
|
||||
// Timings (optional)
|
||||
AuthLatencyMs *int64 `json:"auth_latency_ms"`
|
||||
RoutingLatencyMs *int64 `json:"routing_latency_ms"`
|
||||
UpstreamLatencyMs *int64 `json:"upstream_latency_ms"`
|
||||
ResponseLatencyMs *int64 `json:"response_latency_ms"`
|
||||
TimeToFirstTokenMs *int64 `json:"time_to_first_token_ms"`
|
||||
|
||||
// Retry context
|
||||
RequestBody string `json:"request_body"`
|
||||
RequestBodyTruncated bool `json:"request_body_truncated"`
|
||||
RequestBodyBytes *int `json:"request_body_bytes"`
|
||||
RequestHeaders string `json:"request_headers,omitempty"`
|
||||
|
||||
// vNext metric semantics
|
||||
IsBusinessLimited bool `json:"is_business_limited"`
|
||||
}
|
||||
|
||||
type OpsErrorLogFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
Platform string
|
||||
GroupID *int64
|
||||
AccountID *int64
|
||||
|
||||
StatusCodes []int
|
||||
StatusCodesOther bool
|
||||
Phase string
|
||||
Owner string
|
||||
Source string
|
||||
Resolved *bool
|
||||
Query string
|
||||
UserQuery string // Search by user email
|
||||
|
||||
// Optional correlation keys for exact matching.
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
|
||||
// View controls error categorization for list endpoints.
|
||||
// - errors: show actionable errors (exclude business-limited / 429 / 529)
|
||||
// - excluded: only show excluded errors
|
||||
// - all: show everything
|
||||
View string
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
type OpsErrorLogList struct {
|
||||
Errors []*OpsErrorLog `json:"errors"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
type OpsRetryAttempt struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
RequestedByUserID int64 `json:"requested_by_user_id"`
|
||||
SourceErrorID int64 `json:"source_error_id"`
|
||||
Mode string `json:"mode"`
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
PinnedAccountName string `json:"pinned_account_name"`
|
||||
|
||||
Status string `json:"status"`
|
||||
StartedAt *time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at"`
|
||||
DurationMs *int64 `json:"duration_ms"`
|
||||
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool `json:"success"`
|
||||
HTTPStatusCode *int `json:"http_status_code"`
|
||||
UpstreamRequestID *string `json:"upstream_request_id"`
|
||||
UsedAccountID *int64 `json:"used_account_id"`
|
||||
UsedAccountName string `json:"used_account_name"`
|
||||
ResponsePreview *string `json:"response_preview"`
|
||||
ResponseTruncated *bool `json:"response_truncated"`
|
||||
|
||||
// Optional correlation
|
||||
ResultRequestID *string `json:"result_request_id"`
|
||||
ResultErrorID *int64 `json:"result_error_id"`
|
||||
|
||||
ErrorMessage *string `json:"error_message"`
|
||||
}
|
||||
|
||||
type OpsRetryResult struct {
|
||||
AttemptID int64 `json:"attempt_id"`
|
||||
Mode string `json:"mode"`
|
||||
Status string `json:"status"`
|
||||
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
UsedAccountID *int64 `json:"used_account_id"`
|
||||
|
||||
HTTPStatusCode int `json:"http_status_code"`
|
||||
UpstreamRequestID string `json:"upstream_request_id"`
|
||||
|
||||
ResponsePreview string `json:"response_preview"`
|
||||
ResponseTruncated bool `json:"response_truncated"`
|
||||
|
||||
ErrorMessage string `json:"error_message"`
|
||||
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt time.Time `json:"finished_at"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
}
|
||||
263
backend/internal/service/ops_port.go
Normal file
263
backend/internal/service/ops_port.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpsRepository interface {
|
||||
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
||||
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
||||
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
||||
|
||||
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
||||
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
||||
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
|
||||
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
|
||||
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
|
||||
|
||||
// Lightweight window stats (for realtime WS / quick sampling).
|
||||
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
|
||||
// Lightweight realtime traffic summary (for the Ops dashboard header card).
|
||||
GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error)
|
||||
|
||||
GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error)
|
||||
GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error)
|
||||
GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error)
|
||||
GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error)
|
||||
GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error)
|
||||
|
||||
InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error
|
||||
GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error)
|
||||
|
||||
UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error
|
||||
ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error)
|
||||
|
||||
// Alerts (rules + events)
|
||||
ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error)
|
||||
CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
|
||||
UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
|
||||
DeleteAlertRule(ctx context.Context, id int64) error
|
||||
|
||||
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
|
||||
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
|
||||
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
|
||||
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
|
||||
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
|
||||
|
||||
// Alert silences
|
||||
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
|
||||
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
|
||||
|
||||
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
|
||||
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error)
|
||||
GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error)
|
||||
}
|
||||
|
||||
type OpsInsertErrorLogInput struct {
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
|
||||
UserID *int64
|
||||
APIKeyID *int64
|
||||
AccountID *int64
|
||||
GroupID *int64
|
||||
ClientIP *string
|
||||
|
||||
Platform string
|
||||
Model string
|
||||
RequestPath string
|
||||
Stream bool
|
||||
UserAgent string
|
||||
|
||||
ErrorPhase string
|
||||
ErrorType string
|
||||
Severity string
|
||||
StatusCode int
|
||||
IsBusinessLimited bool
|
||||
IsCountTokens bool // 是否为 count_tokens 请求
|
||||
|
||||
ErrorMessage string
|
||||
ErrorBody string
|
||||
|
||||
ErrorSource string
|
||||
ErrorOwner string
|
||||
|
||||
UpstreamStatusCode *int
|
||||
UpstreamErrorMessage *string
|
||||
UpstreamErrorDetail *string
|
||||
// UpstreamErrors captures all upstream error attempts observed during handling this request.
|
||||
// It is populated during request processing (gin context) and sanitized+serialized by OpsService.
|
||||
UpstreamErrors []*OpsUpstreamErrorEvent
|
||||
// UpstreamErrorsJSON is the sanitized JSON string stored into ops_error_logs.upstream_errors.
|
||||
// It is set by OpsService.RecordError before persisting.
|
||||
UpstreamErrorsJSON *string
|
||||
|
||||
TimeToFirstTokenMs *int64
|
||||
|
||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||
RequestBodyTruncated bool
|
||||
RequestBodyBytes *int
|
||||
RequestHeadersJSON *string // optional json string
|
||||
|
||||
IsRetryable bool
|
||||
RetryCount int
|
||||
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type OpsInsertRetryAttemptInput struct {
|
||||
RequestedByUserID int64
|
||||
SourceErrorID int64
|
||||
Mode string
|
||||
PinnedAccountID *int64
|
||||
|
||||
// running|queued etc.
|
||||
Status string
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
type OpsUpdateRetryAttemptInput struct {
|
||||
ID int64
|
||||
|
||||
// succeeded|failed
|
||||
Status string
|
||||
FinishedAt time.Time
|
||||
DurationMs int64
|
||||
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool
|
||||
HTTPStatusCode *int
|
||||
UpstreamRequestID *string
|
||||
UsedAccountID *int64
|
||||
ResponsePreview *string
|
||||
ResponseTruncated *bool
|
||||
|
||||
// Optional correlation (legacy fields kept)
|
||||
ResultRequestID *string
|
||||
ResultErrorID *int64
|
||||
|
||||
ErrorMessage *string
|
||||
}
|
||||
|
||||
type OpsInsertSystemMetricsInput struct {
|
||||
CreatedAt time.Time
|
||||
WindowMinutes int
|
||||
|
||||
Platform *string
|
||||
GroupID *int64
|
||||
|
||||
SuccessCount int64
|
||||
ErrorCountTotal int64
|
||||
BusinessLimitedCount int64
|
||||
ErrorCountSLA int64
|
||||
|
||||
UpstreamErrorCountExcl429529 int64
|
||||
Upstream429Count int64
|
||||
Upstream529Count int64
|
||||
|
||||
TokenConsumed int64
|
||||
|
||||
QPS *float64
|
||||
TPS *float64
|
||||
|
||||
DurationP50Ms *int
|
||||
DurationP90Ms *int
|
||||
DurationP95Ms *int
|
||||
DurationP99Ms *int
|
||||
DurationAvgMs *float64
|
||||
DurationMaxMs *int
|
||||
|
||||
TTFTP50Ms *int
|
||||
TTFTP90Ms *int
|
||||
TTFTP95Ms *int
|
||||
TTFTP99Ms *int
|
||||
TTFTAvgMs *float64
|
||||
TTFTMaxMs *int
|
||||
|
||||
CPUUsagePercent *float64
|
||||
MemoryUsedMB *int64
|
||||
MemoryTotalMB *int64
|
||||
MemoryUsagePercent *float64
|
||||
|
||||
DBOK *bool
|
||||
RedisOK *bool
|
||||
|
||||
RedisConnTotal *int
|
||||
RedisConnIdle *int
|
||||
|
||||
DBConnActive *int
|
||||
DBConnIdle *int
|
||||
DBConnWaiting *int
|
||||
|
||||
GoroutineCount *int
|
||||
ConcurrencyQueueDepth *int
|
||||
}
|
||||
|
||||
type OpsSystemMetricsSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
WindowMinutes int `json:"window_minutes"`
|
||||
|
||||
CPUUsagePercent *float64 `json:"cpu_usage_percent"`
|
||||
MemoryUsedMB *int64 `json:"memory_used_mb"`
|
||||
MemoryTotalMB *int64 `json:"memory_total_mb"`
|
||||
MemoryUsagePercent *float64 `json:"memory_usage_percent"`
|
||||
|
||||
DBOK *bool `json:"db_ok"`
|
||||
RedisOK *bool `json:"redis_ok"`
|
||||
|
||||
// Config-derived limits (best-effort). These are not historical metrics; they help UI render "current vs max".
|
||||
DBMaxOpenConns *int `json:"db_max_open_conns"`
|
||||
RedisPoolSize *int `json:"redis_pool_size"`
|
||||
|
||||
RedisConnTotal *int `json:"redis_conn_total"`
|
||||
RedisConnIdle *int `json:"redis_conn_idle"`
|
||||
|
||||
DBConnActive *int `json:"db_conn_active"`
|
||||
DBConnIdle *int `json:"db_conn_idle"`
|
||||
DBConnWaiting *int `json:"db_conn_waiting"`
|
||||
|
||||
GoroutineCount *int `json:"goroutine_count"`
|
||||
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
||||
}
|
||||
|
||||
type OpsUpsertJobHeartbeatInput struct {
|
||||
JobName string
|
||||
|
||||
LastRunAt *time.Time
|
||||
LastSuccessAt *time.Time
|
||||
LastErrorAt *time.Time
|
||||
LastError *string
|
||||
LastDurationMs *int64
|
||||
|
||||
// LastResult is an optional human-readable summary of the last successful run.
|
||||
LastResult *string
|
||||
}
|
||||
|
||||
type OpsJobHeartbeat struct {
|
||||
JobName string `json:"job_name"`
|
||||
|
||||
LastRunAt *time.Time `json:"last_run_at"`
|
||||
LastSuccessAt *time.Time `json:"last_success_at"`
|
||||
LastErrorAt *time.Time `json:"last_error_at"`
|
||||
LastError *string `json:"last_error"`
|
||||
LastDurationMs *int64 `json:"last_duration_ms"`
|
||||
LastResult *string `json:"last_result"`
|
||||
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type OpsWindowStats struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
ErrorCountTotal int64 `json:"error_count_total"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
}
|
||||
40
backend/internal/service/ops_query_mode.go
Normal file
40
backend/internal/service/ops_query_mode.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type OpsQueryMode string
|
||||
|
||||
const (
|
||||
OpsQueryModeAuto OpsQueryMode = "auto"
|
||||
OpsQueryModeRaw OpsQueryMode = "raw"
|
||||
OpsQueryModePreagg OpsQueryMode = "preagg"
|
||||
)
|
||||
|
||||
// ErrOpsPreaggregatedNotPopulated indicates that raw logs exist for a window, but the
|
||||
// pre-aggregation tables are not populated yet. This is primarily used to implement
|
||||
// the forced `preagg` mode UX.
|
||||
var ErrOpsPreaggregatedNotPopulated = errors.New("ops pre-aggregated tables not populated")
|
||||
|
||||
func ParseOpsQueryMode(raw string) OpsQueryMode {
|
||||
v := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch v {
|
||||
case string(OpsQueryModeRaw):
|
||||
return OpsQueryModeRaw
|
||||
case string(OpsQueryModePreagg):
|
||||
return OpsQueryModePreagg
|
||||
default:
|
||||
return OpsQueryModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
func (m OpsQueryMode) IsValid() bool {
|
||||
switch m {
|
||||
case OpsQueryModeAuto, OpsQueryModeRaw, OpsQueryModePreagg:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
36
backend/internal/service/ops_realtime.go
Normal file
36
backend/internal/service/ops_realtime.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsRealtimeMonitoringEnabled returns true when realtime ops features are enabled.
|
||||
//
|
||||
// This is a soft switch controlled by the DB setting `ops_realtime_monitoring_enabled`,
|
||||
// and it is also gated by the hard switch/soft switch of overall ops monitoring.
|
||||
func (s *OpsService) IsRealtimeMonitoringEnabled(ctx context.Context) bool {
|
||||
if !s.IsMonitoringEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRealtimeMonitoringEnabled)
|
||||
if err != nil {
|
||||
// Default enabled when key is missing; fail-open on transient errors.
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
81
backend/internal/service/ops_realtime_models.go
Normal file
81
backend/internal/service/ops_realtime_models.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// PlatformConcurrencyInfo aggregates concurrency usage by platform.
|
||||
type PlatformConcurrencyInfo struct {
|
||||
Platform string `json:"platform"`
|
||||
CurrentInUse int64 `json:"current_in_use"`
|
||||
MaxCapacity int64 `json:"max_capacity"`
|
||||
LoadPercentage float64 `json:"load_percentage"`
|
||||
WaitingInQueue int64 `json:"waiting_in_queue"`
|
||||
}
|
||||
|
||||
// GroupConcurrencyInfo aggregates concurrency usage by group.
|
||||
//
|
||||
// Note: one account can belong to multiple groups; group totals are therefore not additive across groups.
|
||||
type GroupConcurrencyInfo struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
CurrentInUse int64 `json:"current_in_use"`
|
||||
MaxCapacity int64 `json:"max_capacity"`
|
||||
LoadPercentage float64 `json:"load_percentage"`
|
||||
WaitingInQueue int64 `json:"waiting_in_queue"`
|
||||
}
|
||||
|
||||
// AccountConcurrencyInfo represents real-time concurrency usage for a single account.
|
||||
type AccountConcurrencyInfo struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
CurrentInUse int64 `json:"current_in_use"`
|
||||
MaxCapacity int64 `json:"max_capacity"`
|
||||
LoadPercentage float64 `json:"load_percentage"`
|
||||
WaitingInQueue int64 `json:"waiting_in_queue"`
|
||||
}
|
||||
|
||||
// PlatformAvailability aggregates account availability by platform.
|
||||
type PlatformAvailability struct {
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// GroupAvailability aggregates account availability by group.
|
||||
type GroupAvailability struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
Platform string `json:"platform"`
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
AvailableCount int64 `json:"available_count"`
|
||||
RateLimitCount int64 `json:"rate_limit_count"`
|
||||
ErrorCount int64 `json:"error_count"`
|
||||
}
|
||||
|
||||
// AccountAvailability represents current availability for a single account.
|
||||
type AccountAvailability struct {
|
||||
AccountID int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
|
||||
Status string `json:"status"`
|
||||
|
||||
IsAvailable bool `json:"is_available"`
|
||||
IsRateLimited bool `json:"is_rate_limited"`
|
||||
IsOverloaded bool `json:"is_overloaded"`
|
||||
HasError bool `json:"has_error"`
|
||||
|
||||
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
|
||||
RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"`
|
||||
OverloadUntil *time.Time `json:"overload_until"`
|
||||
OverloadRemainingSec *int64 `json:"overload_remaining_sec"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
|
||||
}
|
||||
36
backend/internal/service/ops_realtime_traffic.go
Normal file
36
backend/internal/service/ops_realtime_traffic.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the provided window.
|
||||
// This is used by the Ops dashboard "Realtime Traffic" card and is intentionally lightweight.
|
||||
func (s *OpsService) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
if filter.EndTime.Sub(filter.StartTime) > time.Hour {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_TOO_LARGE", "invalid time range: max window is 1 hour")
|
||||
}
|
||||
|
||||
// Realtime traffic summary always uses raw logs (minute granularity peaks).
|
||||
filter.QueryMode = OpsQueryModeRaw
|
||||
|
||||
return s.opsRepo.GetRealtimeTrafficSummary(ctx, filter)
|
||||
}
|
||||
19
backend/internal/service/ops_realtime_traffic_models.go
Normal file
19
backend/internal/service/ops_realtime_traffic_models.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// OpsRealtimeTrafficSummary is a lightweight summary used by the Ops dashboard "Realtime Traffic" card.
|
||||
// It reports QPS/TPS current/peak/avg for the requested time window.
|
||||
type OpsRealtimeTrafficSummary struct {
|
||||
// Window is a normalized label (e.g. "1min", "5min", "30min", "1h").
|
||||
Window string `json:"window"`
|
||||
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
|
||||
QPS OpsRateSummary `json:"qps"`
|
||||
TPS OpsRateSummary `json:"tps"`
|
||||
}
|
||||
151
backend/internal/service/ops_request_details.go
Normal file
151
backend/internal/service/ops_request_details.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpsRequestKind string
|
||||
|
||||
const (
|
||||
OpsRequestKindSuccess OpsRequestKind = "success"
|
||||
OpsRequestKindError OpsRequestKind = "error"
|
||||
)
|
||||
|
||||
// OpsRequestDetail is a request-level view across success (usage_logs) and error (ops_error_logs).
|
||||
// It powers "request drilldown" UIs without exposing full request bodies for successful requests.
|
||||
type OpsRequestDetail struct {
|
||||
Kind OpsRequestKind `json:"kind"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RequestID string `json:"request_id"`
|
||||
|
||||
Platform string `json:"platform,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
|
||||
DurationMs *int `json:"duration_ms,omitempty"`
|
||||
StatusCode *int `json:"status_code,omitempty"`
|
||||
|
||||
// When Kind == "error", ErrorID links to /admin/ops/errors/:id.
|
||||
ErrorID *int64 `json:"error_id,omitempty"`
|
||||
|
||||
Phase string `json:"phase,omitempty"`
|
||||
Severity string `json:"severity,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OpsRequestDetailFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
// kind: success|error|all
|
||||
Kind string
|
||||
|
||||
Platform string
|
||||
GroupID *int64
|
||||
|
||||
UserID *int64
|
||||
APIKeyID *int64
|
||||
AccountID *int64
|
||||
|
||||
Model string
|
||||
RequestID string
|
||||
Query string
|
||||
|
||||
MinDurationMs *int
|
||||
MaxDurationMs *int
|
||||
|
||||
// Sort: created_at_desc (default) or duration_desc.
|
||||
Sort string
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
func (f *OpsRequestDetailFilter) Normalize() (page, pageSize int, startTime, endTime time.Time) {
|
||||
page = 1
|
||||
pageSize = 50
|
||||
endTime = time.Now()
|
||||
startTime = endTime.Add(-1 * time.Hour)
|
||||
|
||||
if f == nil {
|
||||
return page, pageSize, startTime, endTime
|
||||
}
|
||||
|
||||
if f.Page > 0 {
|
||||
page = f.Page
|
||||
}
|
||||
if f.PageSize > 0 {
|
||||
pageSize = f.PageSize
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
|
||||
if f.EndTime != nil {
|
||||
endTime = *f.EndTime
|
||||
}
|
||||
if f.StartTime != nil {
|
||||
startTime = *f.StartTime
|
||||
} else if f.EndTime != nil {
|
||||
startTime = endTime.Add(-1 * time.Hour)
|
||||
}
|
||||
|
||||
if startTime.After(endTime) {
|
||||
startTime, endTime = endTime, startTime
|
||||
}
|
||||
|
||||
return page, pageSize, startTime, endTime
|
||||
}
|
||||
|
||||
type OpsRequestDetailList struct {
|
||||
Items []*OpsRequestDetail `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
func (s *OpsService) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) (*OpsRequestDetailList, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return &OpsRequestDetailList{
|
||||
Items: []*OpsRequestDetail{},
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 50,
|
||||
}, nil
|
||||
}
|
||||
|
||||
page, pageSize, startTime, endTime := filter.Normalize()
|
||||
filterCopy := &OpsRequestDetailFilter{}
|
||||
if filter != nil {
|
||||
*filterCopy = *filter
|
||||
}
|
||||
filterCopy.Page = page
|
||||
filterCopy.PageSize = pageSize
|
||||
filterCopy.StartTime = &startTime
|
||||
filterCopy.EndTime = &endTime
|
||||
|
||||
items, total, err := s.opsRepo.ListRequestDetails(ctx, filterCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if items == nil {
|
||||
items = []*OpsRequestDetail{}
|
||||
}
|
||||
|
||||
return &OpsRequestDetailList{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
720
backend/internal/service/ops_retry.go
Normal file
720
backend/internal/service/ops_retry.go
Normal file
@@ -0,0 +1,720 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
OpsRetryModeClient = "client"
|
||||
OpsRetryModeUpstream = "upstream"
|
||||
)
|
||||
|
||||
const (
|
||||
opsRetryStatusRunning = "running"
|
||||
opsRetryStatusSucceeded = "succeeded"
|
||||
opsRetryStatusFailed = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
opsRetryTimeout = 60 * time.Second
|
||||
opsRetryCaptureBytesLimit = 64 * 1024
|
||||
opsRetryResponsePreviewMax = 8 * 1024
|
||||
opsRetryMinIntervalPerError = 10 * time.Second
|
||||
opsRetryMaxAccountSwitches = 3
|
||||
)
|
||||
|
||||
var opsRetryRequestHeaderAllowlist = map[string]bool{
|
||||
"anthropic-beta": true,
|
||||
"anthropic-version": true,
|
||||
}
|
||||
|
||||
type opsRetryRequestType string
|
||||
|
||||
const (
|
||||
opsRetryTypeMessages opsRetryRequestType = "messages"
|
||||
opsRetryTypeOpenAI opsRetryRequestType = "openai_responses"
|
||||
opsRetryTypeGeminiV1B opsRetryRequestType = "gemini_v1beta"
|
||||
)
|
||||
|
||||
type limitedResponseWriter struct {
|
||||
header http.Header
|
||||
wroteHeader bool
|
||||
|
||||
limit int
|
||||
totalWritten int64
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func newLimitedResponseWriter(limit int) *limitedResponseWriter {
|
||||
if limit <= 0 {
|
||||
limit = 1
|
||||
}
|
||||
return &limitedResponseWriter{
|
||||
header: make(http.Header),
|
||||
limit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.wroteHeader {
|
||||
return
|
||||
}
|
||||
w.wroteHeader = true
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) Write(p []byte) (int, error) {
|
||||
if !w.wroteHeader {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
w.totalWritten += int64(len(p))
|
||||
|
||||
if w.buf.Len() < w.limit {
|
||||
remaining := w.limit - w.buf.Len()
|
||||
if len(p) > remaining {
|
||||
_, _ = w.buf.Write(p[:remaining])
|
||||
} else {
|
||||
_, _ = w.buf.Write(p)
|
||||
}
|
||||
}
|
||||
|
||||
// Pretend we wrote everything to avoid upstream/client code treating it as an error.
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) Flush() {}
|
||||
|
||||
func (w *limitedResponseWriter) bodyBytes() []byte {
|
||||
return w.buf.Bytes()
|
||||
}
|
||||
|
||||
func (w *limitedResponseWriter) truncated() bool {
|
||||
return w.totalWritten > int64(w.limit)
|
||||
}
|
||||
|
||||
const (
|
||||
OpsRetryModeUpstreamEvent = "upstream_event"
|
||||
)
|
||||
|
||||
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
|
||||
mode = strings.ToLower(strings.TrimSpace(mode))
|
||||
switch mode {
|
||||
case OpsRetryModeClient, OpsRetryModeUpstream:
|
||||
default:
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
|
||||
}
|
||||
|
||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errorLog == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||
}
|
||||
|
||||
var pinned *int64
|
||||
if mode == OpsRetryModeUpstream {
|
||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||
pinned = pinnedAccountID
|
||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||
pinned = errorLog.AccountID
|
||||
} else {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
||||
}
|
||||
}
|
||||
|
||||
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
|
||||
}
|
||||
|
||||
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
|
||||
// idx is 0-based. It always pins the original event account_id.
|
||||
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if idx < 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
|
||||
}
|
||||
|
||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errorLog == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
|
||||
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
|
||||
if err != nil {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
|
||||
}
|
||||
if idx >= len(events) {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
|
||||
}
|
||||
ev := events[idx]
|
||||
if ev == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
|
||||
}
|
||||
if ev.AccountID <= 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||
}
|
||||
|
||||
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
|
||||
if upstreamBody == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
|
||||
}
|
||||
|
||||
override := *errorLog
|
||||
override.RequestBody = upstreamBody
|
||||
pinned := ev.AccountID
|
||||
|
||||
// Persist as upstream_event, execute as upstream pinned retry.
|
||||
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
|
||||
}
|
||||
|
||||
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
|
||||
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
||||
}
|
||||
if latest != nil {
|
||||
if strings.EqualFold(latest.Status, opsRetryStatusRunning) || strings.EqualFold(latest.Status, "queued") {
|
||||
return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
|
||||
}
|
||||
|
||||
lastAttemptAt := latest.CreatedAt
|
||||
if latest.FinishedAt != nil && !latest.FinishedAt.IsZero() {
|
||||
lastAttemptAt = *latest.FinishedAt
|
||||
} else if latest.StartedAt != nil && !latest.StartedAt.IsZero() {
|
||||
lastAttemptAt = *latest.StartedAt
|
||||
}
|
||||
|
||||
if time.Since(lastAttemptAt) < opsRetryMinIntervalPerError {
|
||||
return nil, infraerrors.Conflict("OPS_RETRY_TOO_FREQUENT", "Please wait before retrying this error again")
|
||||
}
|
||||
}
|
||||
|
||||
if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||
}
|
||||
|
||||
var pinned *int64
|
||||
if execMode == OpsRetryModeUpstream {
|
||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||
pinned = pinnedAccountID
|
||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||
pinned = errorLog.AccountID
|
||||
} else {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||
}
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
attemptID, err := s.opsRepo.InsertRetryAttempt(ctx, &OpsInsertRetryAttemptInput{
|
||||
RequestedByUserID: requestedByUserID,
|
||||
SourceErrorID: errorID,
|
||||
Mode: mode,
|
||||
PinnedAccountID: pinned,
|
||||
Status: opsRetryStatusRunning,
|
||||
StartedAt: startedAt,
|
||||
})
|
||||
if err != nil {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && string(pqErr.Code) == "23505" {
|
||||
return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
|
||||
}
|
||||
return nil, infraerrors.InternalServer("OPS_RETRY_CREATE_ATTEMPT_FAILED", "Failed to create retry attempt").WithCause(err)
|
||||
}
|
||||
|
||||
result := &OpsRetryResult{
|
||||
AttemptID: attemptID,
|
||||
Mode: mode,
|
||||
Status: opsRetryStatusFailed,
|
||||
PinnedAccountID: pinned,
|
||||
HTTPStatusCode: 0,
|
||||
UpstreamRequestID: "",
|
||||
ResponsePreview: "",
|
||||
ResponseTruncated: false,
|
||||
ErrorMessage: "",
|
||||
StartedAt: startedAt,
|
||||
}
|
||||
|
||||
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
||||
defer cancel()
|
||||
|
||||
execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
|
||||
|
||||
finishedAt := time.Now()
|
||||
result.FinishedAt = finishedAt
|
||||
result.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
|
||||
|
||||
if execRes != nil {
|
||||
result.Status = execRes.status
|
||||
result.UsedAccountID = execRes.usedAccountID
|
||||
result.HTTPStatusCode = execRes.httpStatusCode
|
||||
result.UpstreamRequestID = execRes.upstreamRequestID
|
||||
result.ResponsePreview = execRes.responsePreview
|
||||
result.ResponseTruncated = execRes.responseTruncated
|
||||
result.ErrorMessage = execRes.errorMessage
|
||||
}
|
||||
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer updateCancel()
|
||||
|
||||
var updateErrMsg *string
|
||||
if strings.TrimSpace(result.ErrorMessage) != "" {
|
||||
msg := result.ErrorMessage
|
||||
updateErrMsg = &msg
|
||||
}
|
||||
// Keep legacy result_request_id empty; use upstream_request_id instead.
|
||||
var resultRequestID *string
|
||||
|
||||
finalStatus := result.Status
|
||||
if strings.TrimSpace(finalStatus) == "" {
|
||||
finalStatus = opsRetryStatusFailed
|
||||
}
|
||||
|
||||
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
|
||||
httpStatus := result.HTTPStatusCode
|
||||
upstreamReqID := result.UpstreamRequestID
|
||||
usedAccountID := result.UsedAccountID
|
||||
preview := result.ResponsePreview
|
||||
truncated := result.ResponseTruncated
|
||||
|
||||
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
||||
ID: attemptID,
|
||||
Status: finalStatus,
|
||||
FinishedAt: finishedAt,
|
||||
DurationMs: result.DurationMs,
|
||||
Success: &success,
|
||||
HTTPStatusCode: &httpStatus,
|
||||
UpstreamRequestID: &upstreamReqID,
|
||||
UsedAccountID: usedAccountID,
|
||||
ResponsePreview: &preview,
|
||||
ResponseTruncated: &truncated,
|
||||
ResultRequestID: resultRequestID,
|
||||
ErrorMessage: updateErrMsg,
|
||||
}); err != nil {
|
||||
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
||||
} else if success {
|
||||
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
|
||||
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type opsRetryExecution struct {
|
||||
status string
|
||||
|
||||
usedAccountID *int64
|
||||
httpStatusCode int
|
||||
upstreamRequestID string
|
||||
|
||||
responsePreview string
|
||||
responseTruncated bool
|
||||
|
||||
errorMessage string
|
||||
}
|
||||
|
||||
func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDetail, mode string, pinnedAccountID *int64) *opsRetryExecution {
|
||||
if errorLog == nil {
|
||||
return &opsRetryExecution{
|
||||
status: opsRetryStatusFailed,
|
||||
errorMessage: "missing error log",
|
||||
}
|
||||
}
|
||||
|
||||
reqType := detectOpsRetryType(errorLog.RequestPath)
|
||||
bodyBytes := []byte(errorLog.RequestBody)
|
||||
|
||||
switch reqType {
|
||||
case opsRetryTypeMessages:
|
||||
bodyBytes = FilterThinkingBlocksForRetry(bodyBytes)
|
||||
case opsRetryTypeOpenAI, opsRetryTypeGeminiV1B:
|
||||
// No-op
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case OpsRetryModeUpstream:
|
||||
if pinnedAccountID == nil || *pinnedAccountID <= 0 {
|
||||
return &opsRetryExecution{
|
||||
status: opsRetryStatusFailed,
|
||||
errorMessage: "pinned_account_id required for upstream retry",
|
||||
}
|
||||
}
|
||||
return s.executePinnedRetry(ctx, reqType, errorLog, bodyBytes, *pinnedAccountID)
|
||||
case OpsRetryModeClient:
|
||||
return s.executeClientRetry(ctx, reqType, errorLog, bodyBytes)
|
||||
default:
|
||||
return &opsRetryExecution{
|
||||
status: opsRetryStatusFailed,
|
||||
errorMessage: "invalid retry mode",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func detectOpsRetryType(path string) opsRetryRequestType {
|
||||
p := strings.ToLower(strings.TrimSpace(path))
|
||||
switch {
|
||||
case strings.Contains(p, "/responses"):
|
||||
return opsRetryTypeOpenAI
|
||||
case strings.Contains(p, "/v1beta/"):
|
||||
return opsRetryTypeGeminiV1B
|
||||
default:
|
||||
return opsRetryTypeMessages
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) executePinnedRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, pinnedAccountID int64) *opsRetryExecution {
|
||||
if s.accountRepo == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account repository not available"}
|
||||
}
|
||||
|
||||
account, err := s.accountRepo.GetByID(ctx, pinnedAccountID)
|
||||
if err != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: fmt.Sprintf("account not found: %v", err)}
|
||||
}
|
||||
if account == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account not found"}
|
||||
}
|
||||
if !account.IsSchedulable() {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account is not schedulable"}
|
||||
}
|
||||
if errorLog.GroupID != nil && *errorLog.GroupID > 0 {
|
||||
if !containsInt64(account.GroupIDs, *errorLog.GroupID) {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "pinned account is not in the same group as the original request"}
|
||||
}
|
||||
}
|
||||
|
||||
var release func()
|
||||
if s.concurrencyService != nil {
|
||||
acq, err := s.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: fmt.Sprintf("acquire account slot failed: %v", err)}
|
||||
}
|
||||
if acq == nil || !acq.Acquired {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account concurrency limit reached"}
|
||||
}
|
||||
release = acq.ReleaseFunc
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
usedID := account.ID
|
||||
exec := s.executeWithAccount(ctx, reqType, errorLog, body, account)
|
||||
exec.usedAccountID = &usedID
|
||||
if exec.status == "" {
|
||||
exec.status = opsRetryStatusFailed
|
||||
}
|
||||
return exec
|
||||
}
|
||||
|
||||
func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) *opsRetryExecution {
|
||||
groupID := errorLog.GroupID
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "group_id missing; cannot reselect account"}
|
||||
}
|
||||
|
||||
model, stream, parsedErr := extractRetryModelAndStream(reqType, errorLog, body)
|
||||
if parsedErr != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: parsedErr.Error()}
|
||||
}
|
||||
_ = stream
|
||||
|
||||
excluded := make(map[int64]struct{})
|
||||
switches := 0
|
||||
|
||||
for {
|
||||
if switches >= opsRetryMaxAccountSwitches {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "retry failed after exhausting account failovers"}
|
||||
}
|
||||
|
||||
selection, selErr := s.selectAccountForRetry(ctx, reqType, groupID, model, excluded)
|
||||
if selErr != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "no available accounts"}
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
if !selection.Acquired || selection.ReleaseFunc == nil {
|
||||
excluded[account.ID] = struct{}{}
|
||||
switches++
|
||||
continue
|
||||
}
|
||||
|
||||
exec := func() *opsRetryExecution {
|
||||
defer selection.ReleaseFunc()
|
||||
return s.executeWithAccount(ctx, reqType, errorLog, body, account)
|
||||
}()
|
||||
|
||||
if exec != nil {
|
||||
if exec.status == opsRetryStatusSucceeded {
|
||||
usedID := account.ID
|
||||
exec.usedAccountID = &usedID
|
||||
return exec
|
||||
}
|
||||
// If the gateway services ask for failover, try another account.
|
||||
if s.isFailoverError(exec.errorMessage) {
|
||||
excluded[account.ID] = struct{}{}
|
||||
switches++
|
||||
continue
|
||||
}
|
||||
usedID := account.ID
|
||||
exec.usedAccountID = &usedID
|
||||
return exec
|
||||
}
|
||||
|
||||
excluded[account.ID] = struct{}{}
|
||||
switches++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetryRequestType, groupID *int64, model string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
switch reqType {
|
||||
case opsRetryTypeOpenAI:
|
||||
if s.openAIGatewayService == nil {
|
||||
return nil, fmt.Errorf("openai gateway service not available")
|
||||
}
|
||||
return s.openAIGatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
|
||||
case opsRetryTypeGeminiV1B, opsRetryTypeMessages:
|
||||
if s.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gateway service not available")
|
||||
}
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||
}
|
||||
}
|
||||
|
||||
func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
|
||||
switch reqType {
|
||||
case opsRetryTypeMessages:
|
||||
parsed, parseErr := ParseGatewayRequest(body)
|
||||
if parseErr != nil {
|
||||
return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
|
||||
}
|
||||
return parsed.Model, parsed.Stream, nil
|
||||
case opsRetryTypeOpenAI:
|
||||
var v struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &v); err != nil {
|
||||
return "", false, fmt.Errorf("failed to parse openai request body: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(v.Model), v.Stream, nil
|
||||
case opsRetryTypeGeminiV1B:
|
||||
if strings.TrimSpace(errorLog.Model) == "" {
|
||||
return "", false, fmt.Errorf("missing model for gemini v1beta retry")
|
||||
}
|
||||
return strings.TrimSpace(errorLog.Model), errorLog.Stream, nil
|
||||
default:
|
||||
return "", false, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, account *Account) *opsRetryExecution {
|
||||
if account == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "missing account"}
|
||||
}
|
||||
|
||||
c, w := newOpsRetryContext(ctx, errorLog)
|
||||
|
||||
var err error
|
||||
switch reqType {
|
||||
case opsRetryTypeOpenAI:
|
||||
if s.openAIGatewayService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "openai gateway service not available"}
|
||||
}
|
||||
_, err = s.openAIGatewayService.Forward(ctx, c, account, body)
|
||||
case opsRetryTypeGeminiV1B:
|
||||
if s.geminiCompatService == nil || s.antigravityGatewayService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini services not available"}
|
||||
}
|
||||
modelName := strings.TrimSpace(errorLog.Model)
|
||||
action := "generateContent"
|
||||
if errorLog.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
if account.Platform == PlatformAntigravity {
|
||||
_, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body)
|
||||
} else {
|
||||
_, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body)
|
||||
}
|
||||
case opsRetryTypeMessages:
|
||||
switch account.Platform {
|
||||
case PlatformAntigravity:
|
||||
if s.antigravityGatewayService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"}
|
||||
}
|
||||
_, err = s.antigravityGatewayService.Forward(ctx, c, account, body)
|
||||
case PlatformGemini:
|
||||
if s.geminiCompatService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"}
|
||||
}
|
||||
_, err = s.geminiCompatService.Forward(ctx, c, account, body)
|
||||
default:
|
||||
if s.gatewayService == nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
|
||||
}
|
||||
parsedReq, parseErr := ParseGatewayRequest(body)
|
||||
if parseErr != nil {
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
|
||||
}
|
||||
_, err = s.gatewayService.Forward(ctx, c, account, parsedReq)
|
||||
}
|
||||
default:
|
||||
return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "unsupported retry type"}
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if c != nil && c.Writer != nil {
|
||||
statusCode = c.Writer.Status()
|
||||
}
|
||||
|
||||
upstreamReqID := extractUpstreamRequestID(c)
|
||||
preview, truncated := extractResponsePreview(w)
|
||||
|
||||
exec := &opsRetryExecution{
|
||||
status: opsRetryStatusFailed,
|
||||
httpStatusCode: statusCode,
|
||||
upstreamRequestID: upstreamReqID,
|
||||
responsePreview: preview,
|
||||
responseTruncated: truncated,
|
||||
errorMessage: "",
|
||||
}
|
||||
|
||||
if err == nil && statusCode < 400 {
|
||||
exec.status = opsRetryStatusSucceeded
|
||||
return exec
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
exec.errorMessage = err.Error()
|
||||
} else {
|
||||
exec.errorMessage = fmt.Sprintf("upstream returned status %d", statusCode)
|
||||
}
|
||||
|
||||
return exec
|
||||
}
|
||||
|
||||
func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.Context, *limitedResponseWriter) {
|
||||
w := newLimitedResponseWriter(opsRetryCaptureBytesLimit)
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
path := "/"
|
||||
if errorLog != nil && strings.TrimSpace(errorLog.RequestPath) != "" {
|
||||
path = errorLog.RequestPath
|
||||
}
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost"+path, bytes.NewReader(nil))
|
||||
req.Header.Set("content-type", "application/json")
|
||||
if errorLog != nil && strings.TrimSpace(errorLog.UserAgent) != "" {
|
||||
req.Header.Set("user-agent", errorLog.UserAgent)
|
||||
}
|
||||
// Restore a minimal, whitelisted subset of request headers to improve retry fidelity
|
||||
// (e.g. anthropic-beta / anthropic-version). Never replay auth credentials.
|
||||
if errorLog != nil && strings.TrimSpace(errorLog.RequestHeaders) != "" {
|
||||
var stored map[string]string
|
||||
if err := json.Unmarshal([]byte(errorLog.RequestHeaders), &stored); err == nil {
|
||||
for k, v := range stored {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if !opsRetryRequestHeaderAllowlist[strings.ToLower(key)] {
|
||||
continue
|
||||
}
|
||||
val := strings.TrimSpace(v)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
return c, w
|
||||
}
|
||||
|
||||
func extractUpstreamRequestID(c *gin.Context) string {
|
||||
if c == nil || c.Writer == nil {
|
||||
return ""
|
||||
}
|
||||
h := c.Writer.Header()
|
||||
if h == nil {
|
||||
return ""
|
||||
}
|
||||
for _, key := range []string{"x-request-id", "X-Request-Id", "X-Request-ID"} {
|
||||
if v := strings.TrimSpace(h.Get(key)); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractResponsePreview(w *limitedResponseWriter) (preview string, truncated bool) {
|
||||
if w == nil {
|
||||
return "", false
|
||||
}
|
||||
b := bytes.TrimSpace(w.bodyBytes())
|
||||
if len(b) == 0 {
|
||||
return "", w.truncated()
|
||||
}
|
||||
if len(b) > opsRetryResponsePreviewMax {
|
||||
return string(b[:opsRetryResponsePreviewMax]), true
|
||||
}
|
||||
return string(b), w.truncated()
|
||||
}
|
||||
|
||||
func containsInt64(items []int64, needle int64) bool {
|
||||
for _, v := range items {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpsService) isFailoverError(message string) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(message))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(msg, "upstream error:") && strings.Contains(msg, "failover")
|
||||
}
|
||||
721
backend/internal/service/ops_scheduled_report_service.go
Normal file
721
backend/internal/service/ops_scheduled_report_service.go
Normal file
@@ -0,0 +1,721 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
opsScheduledReportJobName = "ops_scheduled_reports"
|
||||
|
||||
opsScheduledReportLeaderLockKeyDefault = "ops:scheduled_reports:leader"
|
||||
opsScheduledReportLeaderLockTTLDefault = 5 * time.Minute
|
||||
|
||||
opsScheduledReportLastRunKeyPrefix = "ops:scheduled_reports:last_run:"
|
||||
|
||||
opsScheduledReportTickInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
var opsScheduledReportCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
|
||||
var opsScheduledReportReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
type OpsScheduledReportService struct {
|
||||
opsService *OpsService
|
||||
userService *UserService
|
||||
emailService *EmailService
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
|
||||
instanceID string
|
||||
loc *time.Location
|
||||
|
||||
distributedLockOn bool
|
||||
warnNoRedisOnce sync.Once
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
stopCtx context.Context
|
||||
stop context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewOpsScheduledReportService(
|
||||
opsService *OpsService,
|
||||
userService *UserService,
|
||||
emailService *EmailService,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsScheduledReportService {
|
||||
lockOn := cfg == nil || strings.TrimSpace(cfg.RunMode) != config.RunModeSimple
|
||||
|
||||
loc := time.Local
|
||||
if cfg != nil && strings.TrimSpace(cfg.Timezone) != "" {
|
||||
if parsed, err := time.LoadLocation(strings.TrimSpace(cfg.Timezone)); err == nil && parsed != nil {
|
||||
loc = parsed
|
||||
}
|
||||
}
|
||||
return &OpsScheduledReportService{
|
||||
opsService: opsService,
|
||||
userService: userService,
|
||||
emailService: emailService,
|
||||
redisClient: redisClient,
|
||||
cfg: cfg,
|
||||
|
||||
instanceID: uuid.NewString(),
|
||||
loc: loc,
|
||||
distributedLockOn: lockOn,
|
||||
warnNoRedisOnce: sync.Once{},
|
||||
startOnce: sync.Once{},
|
||||
stopOnce: sync.Once{},
|
||||
stopCtx: nil,
|
||||
stop: nil,
|
||||
wg: sync.WaitGroup{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) Start() {
|
||||
s.StartWithContext(context.Background())
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) StartWithContext(ctx context.Context) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if s.opsService == nil || s.emailService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
s.stopCtx, s.stop = context.WithCancel(ctx)
|
||||
s.wg.Add(1)
|
||||
go s.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stop != nil {
|
||||
s.stop()
|
||||
}
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) run() {
|
||||
defer s.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(opsScheduledReportTickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) runOnce() {
|
||||
if s == nil || s.opsService == nil || s.emailService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
ctx, cancel := context.WithTimeout(s.stopCtx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Respect ops monitoring enabled switch.
|
||||
if !s.opsService.IsMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if s.loc != nil {
|
||||
now = now.In(s.loc)
|
||||
}
|
||||
|
||||
reports := s.listScheduledReports(ctx, now)
|
||||
if len(reports) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
reportsTotal := len(reports)
|
||||
reportsDue := 0
|
||||
sentAttempts := 0
|
||||
|
||||
for _, report := range reports {
|
||||
if report == nil || !report.Enabled {
|
||||
continue
|
||||
}
|
||||
if report.NextRunAt.After(now) {
|
||||
continue
|
||||
}
|
||||
reportsDue++
|
||||
|
||||
attempts, err := s.runReport(ctx, report, now)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
return
|
||||
}
|
||||
sentAttempts += attempts
|
||||
}
|
||||
|
||||
result := truncateString(fmt.Sprintf("reports=%d due=%d send_attempts=%d", reportsTotal, reportsDue, sentAttempts), 2048)
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
|
||||
}
|
||||
|
||||
type opsScheduledReport struct {
|
||||
Name string
|
||||
ReportType string
|
||||
Schedule string
|
||||
Enabled bool
|
||||
|
||||
TimeRange time.Duration
|
||||
|
||||
Recipients []string
|
||||
|
||||
ErrorDigestMinCount int
|
||||
AccountHealthErrorRateThreshold float64
|
||||
|
||||
LastRunAt *time.Time
|
||||
NextRunAt time.Time
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, now time.Time) []*opsScheduledReport {
|
||||
if s == nil || s.opsService == nil {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
||||
if err != nil || emailCfg == nil {
|
||||
return nil
|
||||
}
|
||||
if !emailCfg.Report.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
recipients := normalizeEmails(emailCfg.Report.Recipients)
|
||||
|
||||
type reportDef struct {
|
||||
enabled bool
|
||||
name string
|
||||
kind string
|
||||
timeRange time.Duration
|
||||
schedule string
|
||||
}
|
||||
|
||||
defs := []reportDef{
|
||||
{enabled: emailCfg.Report.DailySummaryEnabled, name: "日报", kind: "daily_summary", timeRange: 24 * time.Hour, schedule: emailCfg.Report.DailySummarySchedule},
|
||||
{enabled: emailCfg.Report.WeeklySummaryEnabled, name: "周报", kind: "weekly_summary", timeRange: 7 * 24 * time.Hour, schedule: emailCfg.Report.WeeklySummarySchedule},
|
||||
{enabled: emailCfg.Report.ErrorDigestEnabled, name: "错误摘要", kind: "error_digest", timeRange: 24 * time.Hour, schedule: emailCfg.Report.ErrorDigestSchedule},
|
||||
{enabled: emailCfg.Report.AccountHealthEnabled, name: "账号健康", kind: "account_health", timeRange: 24 * time.Hour, schedule: emailCfg.Report.AccountHealthSchedule},
|
||||
}
|
||||
|
||||
out := make([]*opsScheduledReport, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
if !d.enabled {
|
||||
continue
|
||||
}
|
||||
spec := strings.TrimSpace(d.schedule)
|
||||
if spec == "" {
|
||||
continue
|
||||
}
|
||||
sched, err := opsScheduledReportCronParser.Parse(spec)
|
||||
if err != nil {
|
||||
log.Printf("[OpsScheduledReport] invalid cron spec=%q for report=%s: %v", spec, d.kind, err)
|
||||
continue
|
||||
}
|
||||
|
||||
lastRun := s.getLastRunAt(ctx, d.kind)
|
||||
base := lastRun
|
||||
if base.IsZero() {
|
||||
// Allow a schedule matching the current minute to trigger right after startup.
|
||||
base = now.Add(-1 * time.Minute)
|
||||
}
|
||||
next := sched.Next(base)
|
||||
if next.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
var lastRunPtr *time.Time
|
||||
if !lastRun.IsZero() {
|
||||
lastCopy := lastRun
|
||||
lastRunPtr = &lastCopy
|
||||
}
|
||||
|
||||
out = append(out, &opsScheduledReport{
|
||||
Name: d.name,
|
||||
ReportType: d.kind,
|
||||
Schedule: spec,
|
||||
Enabled: true,
|
||||
|
||||
TimeRange: d.timeRange,
|
||||
|
||||
Recipients: recipients,
|
||||
|
||||
ErrorDigestMinCount: emailCfg.Report.ErrorDigestMinCount,
|
||||
AccountHealthErrorRateThreshold: emailCfg.Report.AccountHealthErrorRateThreshold,
|
||||
|
||||
LastRunAt: lastRunPtr,
|
||||
NextRunAt: next,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) (int, error) {
|
||||
if s == nil || s.opsService == nil || s.emailService == nil || report == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Mark as "run" up-front so a broken SMTP config doesn't spam retries every minute.
|
||||
s.setLastRunAt(ctx, report.ReportType, now)
|
||||
|
||||
content, err := s.generateReportHTML(ctx, report, now)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
// Skip sending when the report decides not to emit content (e.g., digest below min count).
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
recipients := report.Recipients
|
||||
if len(recipients) == 0 && s.userService != nil {
|
||||
admin, err := s.userService.GetFirstAdmin(ctx)
|
||||
if err == nil && admin != nil && strings.TrimSpace(admin.Email) != "" {
|
||||
recipients = []string{strings.TrimSpace(admin.Email)}
|
||||
}
|
||||
}
|
||||
if len(recipients) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
|
||||
|
||||
attempts := 0
|
||||
for _, to := range recipients {
|
||||
addr := strings.TrimSpace(to)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
attempts++
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
}
|
||||
}
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
|
||||
if s == nil || s.opsService == nil || report == nil {
|
||||
return "", fmt.Errorf("service not initialized")
|
||||
}
|
||||
if report.TimeRange <= 0 {
|
||||
return "", fmt.Errorf("invalid time range")
|
||||
}
|
||||
|
||||
end := now.UTC()
|
||||
start := end.Add(-report.TimeRange)
|
||||
|
||||
switch strings.TrimSpace(report.ReportType) {
|
||||
case "daily_summary", "weekly_summary":
|
||||
overview, err := s.opsService.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: "",
|
||||
GroupID: nil,
|
||||
QueryMode: OpsQueryModeAuto,
|
||||
})
|
||||
if err != nil {
|
||||
// If pre-aggregation isn't ready but the report is requested, fall back to raw.
|
||||
if strings.TrimSpace(report.ReportType) == "daily_summary" || strings.TrimSpace(report.ReportType) == "weekly_summary" {
|
||||
overview, err = s.opsService.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: "",
|
||||
GroupID: nil,
|
||||
QueryMode: OpsQueryModeRaw,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return buildOpsSummaryEmailHTML(report.Name, start, end, overview), nil
|
||||
case "error_digest":
|
||||
// Lightweight digest: list recent errors (status>=400) and breakdown by type.
|
||||
startTime := start
|
||||
endTime := end
|
||||
filter := &OpsErrorLogFilter{
|
||||
StartTime: &startTime,
|
||||
EndTime: &endTime,
|
||||
Page: 1,
|
||||
PageSize: 100,
|
||||
}
|
||||
out, err := s.opsService.GetErrorLogs(ctx, filter)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if report.ErrorDigestMinCount > 0 && out != nil && out.Total < report.ErrorDigestMinCount {
|
||||
return "", nil
|
||||
}
|
||||
return buildOpsErrorDigestEmailHTML(report.Name, start, end, out), nil
|
||||
case "account_health":
|
||||
// Best-effort: use account availability (not error rate yet).
|
||||
avail, err := s.opsService.GetAccountAvailability(ctx, "", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
_ = report.AccountHealthErrorRateThreshold // reserved for future per-account error rate report
|
||||
return buildOpsAccountHealthEmailHTML(report.Name, start, end, avail), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unknown report type: %s", report.ReportType)
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpsSummaryEmailHTML(title string, start, end time.Time, overview *OpsDashboardOverview) string {
|
||||
if overview == nil {
|
||||
return fmt.Sprintf("<h2>%s</h2><p>No data.</p>", htmlEscape(title))
|
||||
}
|
||||
|
||||
latP50 := "-"
|
||||
latP99 := "-"
|
||||
if overview.Duration.P50 != nil {
|
||||
latP50 = fmt.Sprintf("%dms", *overview.Duration.P50)
|
||||
}
|
||||
if overview.Duration.P99 != nil {
|
||||
latP99 = fmt.Sprintf("%dms", *overview.Duration.P99)
|
||||
}
|
||||
|
||||
ttftP50 := "-"
|
||||
ttftP99 := "-"
|
||||
if overview.TTFT.P50 != nil {
|
||||
ttftP50 = fmt.Sprintf("%dms", *overview.TTFT.P50)
|
||||
}
|
||||
if overview.TTFT.P99 != nil {
|
||||
ttftP99 = fmt.Sprintf("%dms", *overview.TTFT.P99)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
<h2>%s</h2>
|
||||
<p><b>Period</b>: %s ~ %s (UTC)</p>
|
||||
<ul>
|
||||
<li><b>Total Requests</b>: %d</li>
|
||||
<li><b>Success</b>: %d</li>
|
||||
<li><b>Errors (SLA)</b>: %d</li>
|
||||
<li><b>Business Limited</b>: %d</li>
|
||||
<li><b>SLA</b>: %.2f%%</li>
|
||||
<li><b>Error Rate</b>: %.2f%%</li>
|
||||
<li><b>Upstream Error Rate (excl 429/529)</b>: %.2f%%</li>
|
||||
<li><b>Upstream Errors</b>: excl429/529=%d, 429=%d, 529=%d</li>
|
||||
<li><b>Latency</b>: p50=%s, p99=%s</li>
|
||||
<li><b>TTFT</b>: p50=%s, p99=%s</li>
|
||||
<li><b>Tokens</b>: %d</li>
|
||||
<li><b>QPS</b>: current=%.1f, peak=%.1f, avg=%.1f</li>
|
||||
<li><b>TPS</b>: current=%.1f, peak=%.1f, avg=%.1f</li>
|
||||
</ul>
|
||||
`,
|
||||
htmlEscape(strings.TrimSpace(title)),
|
||||
htmlEscape(start.UTC().Format(time.RFC3339)),
|
||||
htmlEscape(end.UTC().Format(time.RFC3339)),
|
||||
overview.RequestCountTotal,
|
||||
overview.SuccessCount,
|
||||
overview.ErrorCountSLA,
|
||||
overview.BusinessLimitedCount,
|
||||
overview.SLA*100,
|
||||
overview.ErrorRate*100,
|
||||
overview.UpstreamErrorRate*100,
|
||||
overview.UpstreamErrorCountExcl429529,
|
||||
overview.Upstream429Count,
|
||||
overview.Upstream529Count,
|
||||
htmlEscape(latP50),
|
||||
htmlEscape(latP99),
|
||||
htmlEscape(ttftP50),
|
||||
htmlEscape(ttftP99),
|
||||
overview.TokenConsumed,
|
||||
overview.QPS.Current,
|
||||
overview.QPS.Peak,
|
||||
overview.QPS.Avg,
|
||||
overview.TPS.Current,
|
||||
overview.TPS.Peak,
|
||||
overview.TPS.Avg,
|
||||
)
|
||||
}
|
||||
|
||||
func buildOpsErrorDigestEmailHTML(title string, start, end time.Time, list *OpsErrorLogList) string {
|
||||
total := 0
|
||||
recent := []*OpsErrorLog{}
|
||||
if list != nil {
|
||||
total = list.Total
|
||||
recent = list.Errors
|
||||
}
|
||||
if len(recent) > 10 {
|
||||
recent = recent[:10]
|
||||
}
|
||||
|
||||
rows := ""
|
||||
for _, item := range recent {
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
rows += fmt.Sprintf(
|
||||
"<tr><td>%s</td><td>%s</td><td>%d</td><td>%s</td></tr>",
|
||||
htmlEscape(item.CreatedAt.UTC().Format(time.RFC3339)),
|
||||
htmlEscape(item.Platform),
|
||||
item.StatusCode,
|
||||
htmlEscape(truncateString(item.Message, 180)),
|
||||
)
|
||||
}
|
||||
if rows == "" {
|
||||
rows = "<tr><td colspan=\"4\">No recent errors.</td></tr>"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
<h2>%s</h2>
|
||||
<p><b>Period</b>: %s ~ %s (UTC)</p>
|
||||
<p><b>Total Errors</b>: %d</p>
|
||||
<h3>Recent</h3>
|
||||
<table border="1" cellpadding="6" cellspacing="0" style="border-collapse:collapse;">
|
||||
<thead><tr><th>Time</th><th>Platform</th><th>Status</th><th>Message</th></tr></thead>
|
||||
<tbody>%s</tbody>
|
||||
</table>
|
||||
`,
|
||||
htmlEscape(strings.TrimSpace(title)),
|
||||
htmlEscape(start.UTC().Format(time.RFC3339)),
|
||||
htmlEscape(end.UTC().Format(time.RFC3339)),
|
||||
total,
|
||||
rows,
|
||||
)
|
||||
}
|
||||
|
||||
func buildOpsAccountHealthEmailHTML(title string, start, end time.Time, avail *OpsAccountAvailability) string {
|
||||
total := 0
|
||||
available := 0
|
||||
rateLimited := 0
|
||||
hasError := 0
|
||||
|
||||
if avail != nil && avail.Accounts != nil {
|
||||
for _, a := range avail.Accounts {
|
||||
if a == nil {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
if a.IsAvailable {
|
||||
available++
|
||||
}
|
||||
if a.IsRateLimited {
|
||||
rateLimited++
|
||||
}
|
||||
if a.HasError {
|
||||
hasError++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
<h2>%s</h2>
|
||||
<p><b>Period</b>: %s ~ %s (UTC)</p>
|
||||
<ul>
|
||||
<li><b>Total Accounts</b>: %d</li>
|
||||
<li><b>Available</b>: %d</li>
|
||||
<li><b>Rate Limited</b>: %d</li>
|
||||
<li><b>Error</b>: %d</li>
|
||||
</ul>
|
||||
<p>Note: This report currently reflects account availability status only.</p>
|
||||
`,
|
||||
htmlEscape(strings.TrimSpace(title)),
|
||||
htmlEscape(start.UTC().Format(time.RFC3339)),
|
||||
htmlEscape(end.UTC().Format(time.RFC3339)),
|
||||
total,
|
||||
available,
|
||||
rateLimited,
|
||||
hasError,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if s == nil || !s.distributedLockOn {
|
||||
return nil, true
|
||||
}
|
||||
if s.redisClient == nil {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsScheduledReport] redis not configured; running without distributed lock")
|
||||
})
|
||||
return nil, true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
key := opsScheduledReportLeaderLockKeyDefault
|
||||
ttl := opsScheduledReportLeaderLockTTLDefault
|
||||
if strings.TrimSpace(key) == "" {
|
||||
key = "ops:scheduled_reports:leader"
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = 5 * time.Minute
|
||||
}
|
||||
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err != nil {
|
||||
// Prefer fail-closed to avoid duplicate report sends when Redis is flaky.
|
||||
log.Printf("[OpsScheduledReport] leader lock SetNX failed; skipping this cycle: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsScheduledReportReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) getLastRunAt(ctx context.Context, reportType string) time.Time {
|
||||
if s == nil || s.redisClient == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
kind := strings.TrimSpace(reportType)
|
||||
if kind == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
key := opsScheduledReportLastRunKeyPrefix + kind
|
||||
|
||||
raw, err := s.redisClient.Get(ctx, key).Result()
|
||||
if err != nil || strings.TrimSpace(raw) == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
sec, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
|
||||
if err != nil || sec <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
last := time.Unix(sec, 0)
|
||||
// Cron schedules are interpreted in the configured timezone (s.loc). Ensure the base time
|
||||
// passed into cron.Next() uses the same location; otherwise the job will drift by timezone
|
||||
// offset (e.g. Asia/Shanghai default would run 8h later after the first execution).
|
||||
if s.loc != nil {
|
||||
return last.In(s.loc)
|
||||
}
|
||||
return last.UTC()
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType string, t time.Time) {
|
||||
if s == nil || s.redisClient == nil {
|
||||
return
|
||||
}
|
||||
kind := strings.TrimSpace(reportType)
|
||||
if kind == "" {
|
||||
return
|
||||
}
|
||||
if t.IsZero() {
|
||||
t = time.Now().UTC()
|
||||
}
|
||||
key := opsScheduledReportLastRunKeyPrefix + kind
|
||||
_ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err()
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||
if s == nil || s.opsService == nil || s.opsService.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg := strings.TrimSpace(result)
|
||||
if msg == "" {
|
||||
msg = "ok"
|
||||
}
|
||||
msg = truncateString(msg, 2048)
|
||||
_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsScheduledReportJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &msg,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||
if s == nil || s.opsService == nil || s.opsService.opsRepo == nil || err == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsScheduledReportJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &now,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeEmails(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(in))
|
||||
out := make([]string, 0, len(in))
|
||||
for _, raw := range in {
|
||||
addr := strings.ToLower(strings.TrimSpace(raw))
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
613
backend/internal/service/ops_service.go
Normal file
613
backend/internal/service/ops_service.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
|
||||
|
||||
const (
|
||||
opsMaxStoredRequestBodyBytes = 10 * 1024
|
||||
opsMaxStoredErrorBodyBytes = 20 * 1024
|
||||
)
|
||||
|
||||
// OpsService provides ingestion and query APIs for the Ops monitoring module.
|
||||
type OpsService struct {
|
||||
opsRepo OpsRepository
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
|
||||
accountRepo AccountRepository
|
||||
|
||||
// getAccountAvailability is a unit-test hook for overriding account availability lookup.
|
||||
getAccountAvailability func(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error)
|
||||
|
||||
concurrencyService *ConcurrencyService
|
||||
gatewayService *GatewayService
|
||||
openAIGatewayService *OpenAIGatewayService
|
||||
geminiCompatService *GeminiMessagesCompatService
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
}
|
||||
|
||||
func NewOpsService(
|
||||
opsRepo OpsRepository,
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
accountRepo AccountRepository,
|
||||
concurrencyService *ConcurrencyService,
|
||||
gatewayService *GatewayService,
|
||||
openAIGatewayService *OpenAIGatewayService,
|
||||
geminiCompatService *GeminiMessagesCompatService,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
) *OpsService {
|
||||
return &OpsService{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
|
||||
accountRepo: accountRepo,
|
||||
|
||||
concurrencyService: concurrencyService,
|
||||
gatewayService: gatewayService,
|
||||
openAIGatewayService: openAIGatewayService,
|
||||
geminiCompatService: geminiCompatService,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error {
|
||||
if s.IsMonitoringEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
return ErrOpsDisabled
|
||||
}
|
||||
|
||||
func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
|
||||
// Hard switch: disable ops entirely.
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return false
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||
if err != nil {
|
||||
// Default enabled when key is missing, and fail-open on transient errors
|
||||
// (ops should never block gateway traffic).
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
if !s.IsMonitoringEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure timestamps are always populated.
|
||||
if entry.CreatedAt.IsZero() {
|
||||
entry.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
// Ensure required fields exist (DB has NOT NULL constraints).
|
||||
entry.ErrorPhase = strings.TrimSpace(entry.ErrorPhase)
|
||||
entry.ErrorType = strings.TrimSpace(entry.ErrorType)
|
||||
if entry.ErrorPhase == "" {
|
||||
entry.ErrorPhase = "internal"
|
||||
}
|
||||
if entry.ErrorType == "" {
|
||||
entry.ErrorType = "api_error"
|
||||
}
|
||||
|
||||
// Sanitize + trim request body (errors only).
|
||||
if len(rawRequestBody) > 0 {
|
||||
sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes)
|
||||
if sanitized != "" {
|
||||
entry.RequestBodyJSON = &sanitized
|
||||
}
|
||||
entry.RequestBodyTruncated = truncated
|
||||
entry.RequestBodyBytes = &bytesLen
|
||||
}
|
||||
|
||||
// Sanitize + truncate error_body to avoid storing sensitive data.
|
||||
if strings.TrimSpace(entry.ErrorBody) != "" {
|
||||
sanitized, _ := sanitizeErrorBodyForStorage(entry.ErrorBody, opsMaxStoredErrorBodyBytes)
|
||||
entry.ErrorBody = sanitized
|
||||
}
|
||||
|
||||
// Sanitize upstream error context if provided by gateway services.
|
||||
if entry.UpstreamStatusCode != nil && *entry.UpstreamStatusCode <= 0 {
|
||||
entry.UpstreamStatusCode = nil
|
||||
}
|
||||
if entry.UpstreamErrorMessage != nil {
|
||||
msg := strings.TrimSpace(*entry.UpstreamErrorMessage)
|
||||
msg = sanitizeUpstreamErrorMessage(msg)
|
||||
msg = truncateString(msg, 2048)
|
||||
if strings.TrimSpace(msg) == "" {
|
||||
entry.UpstreamErrorMessage = nil
|
||||
} else {
|
||||
entry.UpstreamErrorMessage = &msg
|
||||
}
|
||||
}
|
||||
if entry.UpstreamErrorDetail != nil {
|
||||
detail := strings.TrimSpace(*entry.UpstreamErrorDetail)
|
||||
if detail == "" {
|
||||
entry.UpstreamErrorDetail = nil
|
||||
} else {
|
||||
sanitized, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
||||
if strings.TrimSpace(sanitized) == "" {
|
||||
entry.UpstreamErrorDetail = nil
|
||||
} else {
|
||||
entry.UpstreamErrorDetail = &sanitized
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize + serialize upstream error events list.
|
||||
if len(entry.UpstreamErrors) > 0 {
|
||||
const maxEvents = 32
|
||||
events := entry.UpstreamErrors
|
||||
if len(events) > maxEvents {
|
||||
events = events[len(events)-maxEvents:]
|
||||
}
|
||||
|
||||
sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events))
|
||||
for _, ev := range events {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
out := *ev
|
||||
|
||||
out.Platform = strings.TrimSpace(out.Platform)
|
||||
out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128)
|
||||
out.Kind = truncateString(strings.TrimSpace(out.Kind), 64)
|
||||
|
||||
if out.AccountID < 0 {
|
||||
out.AccountID = 0
|
||||
}
|
||||
if out.UpstreamStatusCode < 0 {
|
||||
out.UpstreamStatusCode = 0
|
||||
}
|
||||
if out.AtUnixMs < 0 {
|
||||
out.AtUnixMs = 0
|
||||
}
|
||||
|
||||
msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message))
|
||||
msg = truncateString(msg, 2048)
|
||||
out.Message = msg
|
||||
|
||||
detail := strings.TrimSpace(out.Detail)
|
||||
if detail != "" {
|
||||
// Keep upstream detail small; request bodies are not stored here, only upstream error payloads.
|
||||
sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes)
|
||||
out.Detail = sanitizedDetail
|
||||
} else {
|
||||
out.Detail = ""
|
||||
}
|
||||
|
||||
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
||||
if out.UpstreamRequestBody != "" {
|
||||
// Reuse the same sanitization/trimming strategy as request body storage.
|
||||
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
||||
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
||||
if sanitized != "" {
|
||||
out.UpstreamRequestBody = sanitized
|
||||
if truncated {
|
||||
out.Kind = strings.TrimSpace(out.Kind)
|
||||
if out.Kind == "" {
|
||||
out.Kind = "upstream"
|
||||
}
|
||||
out.Kind = out.Kind + ":request_body_truncated"
|
||||
}
|
||||
} else {
|
||||
out.UpstreamRequestBody = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Drop fully-empty events (can happen if only status code was known).
|
||||
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
evCopy := out
|
||||
sanitized = append(sanitized, &evCopy)
|
||||
}
|
||||
|
||||
entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized)
|
||||
entry.UpstreamErrors = nil
|
||||
}
|
||||
|
||||
if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil {
|
||||
// Never bubble up to gateway; best-effort logging.
|
||||
log.Printf("[Ops] RecordError failed: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
|
||||
}
|
||||
result, err := s.opsRepo.ListErrorLogs(ctx, filter)
|
||||
if err != nil {
|
||||
log.Printf("[Ops] GetErrorLogs failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
detail, err := s.opsRepo.GetErrorLogByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
return nil, infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
|
||||
}
|
||||
return detail, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if errorID <= 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
|
||||
}
|
||||
items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return []*OpsRetryAttempt{}, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if errorID <= 0 {
|
||||
return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
|
||||
}
|
||||
// Best-effort ensure the error exists
|
||||
if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
|
||||
}
|
||||
return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
|
||||
}
|
||||
|
||||
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
|
||||
bytesLen = len(raw)
|
||||
if len(raw) == 0 {
|
||||
return "", false, 0
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
// If it's not valid JSON, don't store (retry would not be reliable anyway).
|
||||
return "", false, bytesLen
|
||||
}
|
||||
|
||||
decoded = redactSensitiveJSON(decoded)
|
||||
|
||||
encoded, err := json.Marshal(decoded)
|
||||
if err != nil {
|
||||
return "", false, bytesLen
|
||||
}
|
||||
if len(encoded) <= maxBytes {
|
||||
return string(encoded), false, bytesLen
|
||||
}
|
||||
|
||||
// Trim conversation history to keep the most recent context.
|
||||
if root, ok := decoded.(map[string]any); ok {
|
||||
if trimmed, ok := trimConversationArrays(root, maxBytes); ok {
|
||||
encoded2, err2 := json.Marshal(trimmed)
|
||||
if err2 == nil && len(encoded2) <= maxBytes {
|
||||
return string(encoded2), true, bytesLen
|
||||
}
|
||||
// Fallthrough: keep shrinking.
|
||||
decoded = trimmed
|
||||
}
|
||||
|
||||
essential := shrinkToEssentials(root)
|
||||
encoded3, err3 := json.Marshal(essential)
|
||||
if err3 == nil && len(encoded3) <= maxBytes {
|
||||
return string(encoded3), true, bytesLen
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: keep JSON shape but drop big fields.
|
||||
// This avoids downstream code that expects certain top-level keys from crashing.
|
||||
if root, ok := decoded.(map[string]any); ok {
|
||||
placeholder := shallowCopyMap(root)
|
||||
placeholder["request_body_truncated"] = true
|
||||
|
||||
// Replace potentially huge arrays/strings, but keep the keys present.
|
||||
for _, k := range []string{"messages", "contents", "input", "prompt"} {
|
||||
if _, exists := placeholder[k]; exists {
|
||||
placeholder[k] = []any{}
|
||||
}
|
||||
}
|
||||
for _, k := range []string{"text"} {
|
||||
if _, exists := placeholder[k]; exists {
|
||||
placeholder[k] = ""
|
||||
}
|
||||
}
|
||||
|
||||
encoded4, err4 := json.Marshal(placeholder)
|
||||
if err4 == nil {
|
||||
if len(encoded4) <= maxBytes {
|
||||
return string(encoded4), true, bytesLen
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final fallback: minimal valid JSON.
|
||||
encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
|
||||
if err4 != nil {
|
||||
return "", true, bytesLen
|
||||
}
|
||||
return string(encoded4), true, bytesLen
|
||||
}
|
||||
|
||||
func redactSensitiveJSON(v any) any {
|
||||
switch t := v.(type) {
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(t))
|
||||
for k, vv := range t {
|
||||
if isSensitiveKey(k) {
|
||||
out[k] = "[REDACTED]"
|
||||
continue
|
||||
}
|
||||
out[k] = redactSensitiveJSON(vv)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, 0, len(t))
|
||||
for _, vv := range t {
|
||||
out = append(out, redactSensitiveJSON(vv))
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func isSensitiveKey(key string) bool {
|
||||
k := strings.ToLower(strings.TrimSpace(key))
|
||||
if k == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Exact matches (common credential fields).
|
||||
switch k {
|
||||
case "authorization",
|
||||
"proxy-authorization",
|
||||
"x-api-key",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"id_token",
|
||||
"session_token",
|
||||
"token",
|
||||
"password",
|
||||
"passwd",
|
||||
"passphrase",
|
||||
"secret",
|
||||
"client_secret",
|
||||
"private_key",
|
||||
"jwt",
|
||||
"signature",
|
||||
"accesskeyid",
|
||||
"secretaccesskey":
|
||||
return true
|
||||
}
|
||||
|
||||
// Suffix matches.
|
||||
for _, suffix := range []string{
|
||||
"_secret",
|
||||
"_token",
|
||||
"_id_token",
|
||||
"_session_token",
|
||||
"_password",
|
||||
"_passwd",
|
||||
"_passphrase",
|
||||
"_key",
|
||||
"secret_key",
|
||||
"private_key",
|
||||
} {
|
||||
if strings.HasSuffix(k, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Substring matches (conservative, but errs on the side of privacy).
|
||||
for _, sub := range []string{
|
||||
"secret",
|
||||
"token",
|
||||
"password",
|
||||
"passwd",
|
||||
"passphrase",
|
||||
"privatekey",
|
||||
"private_key",
|
||||
"apikey",
|
||||
"api_key",
|
||||
"accesskeyid",
|
||||
"secretaccesskey",
|
||||
"bearer",
|
||||
"cookie",
|
||||
"credential",
|
||||
"session",
|
||||
"jwt",
|
||||
"signature",
|
||||
} {
|
||||
if strings.Contains(k, sub) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func trimConversationArrays(root map[string]any, maxBytes int) (map[string]any, bool) {
|
||||
// Supported: anthropic/openai: messages; gemini: contents.
|
||||
if out, ok := trimArrayField(root, "messages", maxBytes); ok {
|
||||
return out, true
|
||||
}
|
||||
if out, ok := trimArrayField(root, "contents", maxBytes); ok {
|
||||
return out, true
|
||||
}
|
||||
return root, false
|
||||
}
|
||||
|
||||
func trimArrayField(root map[string]any, field string, maxBytes int) (map[string]any, bool) {
|
||||
raw, ok := root[field]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
arr, ok := raw.([]any)
|
||||
if !ok || len(arr) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Keep at least the last message/content. Use binary search so we don't marshal O(n) times.
|
||||
// We are dropping from the *front* of the array (oldest context first).
|
||||
lo := 0
|
||||
hi := len(arr) - 1 // inclusive; hi ensures at least one item remains
|
||||
|
||||
var best map[string]any
|
||||
found := false
|
||||
|
||||
for lo <= hi {
|
||||
mid := (lo + hi) / 2
|
||||
candidateArr := arr[mid:]
|
||||
if len(candidateArr) == 0 {
|
||||
lo = mid + 1
|
||||
continue
|
||||
}
|
||||
|
||||
next := shallowCopyMap(root)
|
||||
next[field] = candidateArr
|
||||
encoded, err := json.Marshal(next)
|
||||
if err != nil {
|
||||
// If marshal fails, try dropping more.
|
||||
lo = mid + 1
|
||||
continue
|
||||
}
|
||||
|
||||
if len(encoded) <= maxBytes {
|
||||
best = next
|
||||
found = true
|
||||
// Try to keep more context by dropping fewer items.
|
||||
hi = mid - 1
|
||||
continue
|
||||
}
|
||||
|
||||
// Need to drop more.
|
||||
lo = mid + 1
|
||||
}
|
||||
|
||||
if found {
|
||||
return best, true
|
||||
}
|
||||
|
||||
// Nothing fit (even with only one element); return the smallest slice and let the
|
||||
// caller fall back to shrinkToEssentials().
|
||||
next := shallowCopyMap(root)
|
||||
next[field] = arr[len(arr)-1:]
|
||||
return next, true
|
||||
}
|
||||
|
||||
func shrinkToEssentials(root map[string]any) map[string]any {
|
||||
out := make(map[string]any)
|
||||
for _, key := range []string{"model", "stream", "max_tokens", "temperature", "top_p", "top_k"} {
|
||||
if v, ok := root[key]; ok {
|
||||
out[key] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Keep only the last element of the conversation array.
|
||||
if v, ok := root["messages"]; ok {
|
||||
if arr, ok := v.([]any); ok && len(arr) > 0 {
|
||||
out["messages"] = []any{arr[len(arr)-1]}
|
||||
}
|
||||
}
|
||||
if v, ok := root["contents"]; ok {
|
||||
if arr, ok := v.([]any); ok && len(arr) > 0 {
|
||||
out["contents"] = []any{arr[len(arr)-1]}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func shallowCopyMap(m map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(m))
|
||||
for k, v := range m {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, truncated bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Prefer JSON-safe sanitization when possible.
|
||||
if out, trunc, _ := sanitizeAndTrimRequestBody([]byte(raw), maxBytes); out != "" {
|
||||
return out, trunc
|
||||
}
|
||||
|
||||
// Non-JSON: best-effort truncate.
|
||||
if maxBytes > 0 && len(raw) > maxBytes {
|
||||
return truncateString(raw, maxBytes), true
|
||||
}
|
||||
return raw, false
|
||||
}
|
||||
562
backend/internal/service/ops_settings.go
Normal file
562
backend/internal/service/ops_settings.go
Normal file
@@ -0,0 +1,562 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAlertEvaluatorLeaderLockKeyDefault = "ops:alert:evaluator:leader"
|
||||
opsAlertEvaluatorLeaderLockTTLDefault = 30 * time.Second
|
||||
)
|
||||
|
||||
// =========================
|
||||
// Email notification config
|
||||
// =========================
|
||||
|
||||
func (s *OpsService) GetEmailNotificationConfig(ctx context.Context) (*OpsEmailNotificationConfig, error) {
|
||||
defaultCfg := defaultOpsEmailNotificationConfig()
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsEmailNotificationConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
// Initialize defaults on first read (best-effort).
|
||||
if b, mErr := json.Marshal(defaultCfg); mErr == nil {
|
||||
_ = s.settingRepo.Set(ctx, SettingKeyOpsEmailNotificationConfig, string(b))
|
||||
}
|
||||
return defaultCfg, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsEmailNotificationConfig{}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
// Corrupted JSON should not break ops UI; fall back to defaults.
|
||||
return defaultCfg, nil
|
||||
}
|
||||
normalizeOpsEmailNotificationConfig(cfg)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateEmailNotificationConfig(ctx context.Context, req *OpsEmailNotificationConfigUpdateRequest) (*OpsEmailNotificationConfig, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if req == nil {
|
||||
return nil, errors.New("invalid request")
|
||||
}
|
||||
|
||||
cfg, err := s.GetEmailNotificationConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.Alert != nil {
|
||||
cfg.Alert.Enabled = req.Alert.Enabled
|
||||
if req.Alert.Recipients != nil {
|
||||
cfg.Alert.Recipients = req.Alert.Recipients
|
||||
}
|
||||
cfg.Alert.MinSeverity = strings.TrimSpace(req.Alert.MinSeverity)
|
||||
cfg.Alert.RateLimitPerHour = req.Alert.RateLimitPerHour
|
||||
cfg.Alert.BatchingWindowSeconds = req.Alert.BatchingWindowSeconds
|
||||
cfg.Alert.IncludeResolvedAlerts = req.Alert.IncludeResolvedAlerts
|
||||
}
|
||||
|
||||
if req.Report != nil {
|
||||
cfg.Report.Enabled = req.Report.Enabled
|
||||
if req.Report.Recipients != nil {
|
||||
cfg.Report.Recipients = req.Report.Recipients
|
||||
}
|
||||
cfg.Report.DailySummaryEnabled = req.Report.DailySummaryEnabled
|
||||
cfg.Report.DailySummarySchedule = strings.TrimSpace(req.Report.DailySummarySchedule)
|
||||
cfg.Report.WeeklySummaryEnabled = req.Report.WeeklySummaryEnabled
|
||||
cfg.Report.WeeklySummarySchedule = strings.TrimSpace(req.Report.WeeklySummarySchedule)
|
||||
cfg.Report.ErrorDigestEnabled = req.Report.ErrorDigestEnabled
|
||||
cfg.Report.ErrorDigestSchedule = strings.TrimSpace(req.Report.ErrorDigestSchedule)
|
||||
cfg.Report.ErrorDigestMinCount = req.Report.ErrorDigestMinCount
|
||||
cfg.Report.AccountHealthEnabled = req.Report.AccountHealthEnabled
|
||||
cfg.Report.AccountHealthSchedule = strings.TrimSpace(req.Report.AccountHealthSchedule)
|
||||
cfg.Report.AccountHealthErrorRateThreshold = req.Report.AccountHealthErrorRateThreshold
|
||||
}
|
||||
|
||||
if err := validateOpsEmailNotificationConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizeOpsEmailNotificationConfig(cfg)
|
||||
raw, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsEmailNotificationConfig, string(raw)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func defaultOpsEmailNotificationConfig() *OpsEmailNotificationConfig {
|
||||
return &OpsEmailNotificationConfig{
|
||||
Alert: OpsEmailAlertConfig{
|
||||
Enabled: true,
|
||||
Recipients: []string{},
|
||||
MinSeverity: "",
|
||||
RateLimitPerHour: 0,
|
||||
BatchingWindowSeconds: 0,
|
||||
IncludeResolvedAlerts: false,
|
||||
},
|
||||
Report: OpsEmailReportConfig{
|
||||
Enabled: false,
|
||||
Recipients: []string{},
|
||||
DailySummaryEnabled: false,
|
||||
DailySummarySchedule: "0 9 * * *",
|
||||
WeeklySummaryEnabled: false,
|
||||
WeeklySummarySchedule: "0 9 * * 1",
|
||||
ErrorDigestEnabled: false,
|
||||
ErrorDigestSchedule: "0 9 * * *",
|
||||
ErrorDigestMinCount: 10,
|
||||
AccountHealthEnabled: false,
|
||||
AccountHealthSchedule: "0 9 * * *",
|
||||
AccountHealthErrorRateThreshold: 10.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpsEmailNotificationConfig(cfg *OpsEmailNotificationConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if cfg.Alert.Recipients == nil {
|
||||
cfg.Alert.Recipients = []string{}
|
||||
}
|
||||
if cfg.Report.Recipients == nil {
|
||||
cfg.Report.Recipients = []string{}
|
||||
}
|
||||
|
||||
cfg.Alert.MinSeverity = strings.TrimSpace(cfg.Alert.MinSeverity)
|
||||
cfg.Report.DailySummarySchedule = strings.TrimSpace(cfg.Report.DailySummarySchedule)
|
||||
cfg.Report.WeeklySummarySchedule = strings.TrimSpace(cfg.Report.WeeklySummarySchedule)
|
||||
cfg.Report.ErrorDigestSchedule = strings.TrimSpace(cfg.Report.ErrorDigestSchedule)
|
||||
cfg.Report.AccountHealthSchedule = strings.TrimSpace(cfg.Report.AccountHealthSchedule)
|
||||
|
||||
// Fill missing schedules with defaults to avoid breaking cron logic if clients send empty strings.
|
||||
if cfg.Report.DailySummarySchedule == "" {
|
||||
cfg.Report.DailySummarySchedule = "0 9 * * *"
|
||||
}
|
||||
if cfg.Report.WeeklySummarySchedule == "" {
|
||||
cfg.Report.WeeklySummarySchedule = "0 9 * * 1"
|
||||
}
|
||||
if cfg.Report.ErrorDigestSchedule == "" {
|
||||
cfg.Report.ErrorDigestSchedule = "0 9 * * *"
|
||||
}
|
||||
if cfg.Report.AccountHealthSchedule == "" {
|
||||
cfg.Report.AccountHealthSchedule = "0 9 * * *"
|
||||
}
|
||||
}
|
||||
|
||||
func validateOpsEmailNotificationConfig(cfg *OpsEmailNotificationConfig) error {
|
||||
if cfg == nil {
|
||||
return errors.New("invalid config")
|
||||
}
|
||||
|
||||
if cfg.Alert.RateLimitPerHour < 0 {
|
||||
return errors.New("alert.rate_limit_per_hour must be >= 0")
|
||||
}
|
||||
if cfg.Alert.BatchingWindowSeconds < 0 {
|
||||
return errors.New("alert.batching_window_seconds must be >= 0")
|
||||
}
|
||||
switch strings.TrimSpace(cfg.Alert.MinSeverity) {
|
||||
case "", "critical", "warning", "info":
|
||||
default:
|
||||
return errors.New("alert.min_severity must be one of: critical, warning, info, or empty")
|
||||
}
|
||||
|
||||
if cfg.Report.ErrorDigestMinCount < 0 {
|
||||
return errors.New("report.error_digest_min_count must be >= 0")
|
||||
}
|
||||
if cfg.Report.AccountHealthErrorRateThreshold < 0 || cfg.Report.AccountHealthErrorRateThreshold > 100 {
|
||||
return errors.New("report.account_health_error_rate_threshold must be between 0 and 100")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Alert runtime settings
|
||||
// =========================
|
||||
|
||||
func defaultOpsAlertRuntimeSettings() *OpsAlertRuntimeSettings {
|
||||
return &OpsAlertRuntimeSettings{
|
||||
EvaluationIntervalSeconds: 60,
|
||||
DistributedLock: OpsDistributedLockSettings{
|
||||
Enabled: true,
|
||||
Key: opsAlertEvaluatorLeaderLockKeyDefault,
|
||||
TTLSeconds: int(opsAlertEvaluatorLeaderLockTTLDefault.Seconds()),
|
||||
},
|
||||
Silencing: OpsAlertSilencingSettings{
|
||||
Enabled: false,
|
||||
GlobalUntilRFC3339: "",
|
||||
GlobalReason: "",
|
||||
Entries: []OpsAlertSilenceEntry{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpsDistributedLockSettings(s *OpsDistributedLockSettings, defaultKey string, defaultTTLSeconds int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.Key = strings.TrimSpace(s.Key)
|
||||
if s.Key == "" {
|
||||
s.Key = defaultKey
|
||||
}
|
||||
if s.TTLSeconds <= 0 {
|
||||
s.TTLSeconds = defaultTTLSeconds
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpsAlertSilencingSettings(s *OpsAlertSilencingSettings) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.GlobalUntilRFC3339 = strings.TrimSpace(s.GlobalUntilRFC3339)
|
||||
s.GlobalReason = strings.TrimSpace(s.GlobalReason)
|
||||
if s.Entries == nil {
|
||||
s.Entries = []OpsAlertSilenceEntry{}
|
||||
}
|
||||
for i := range s.Entries {
|
||||
s.Entries[i].UntilRFC3339 = strings.TrimSpace(s.Entries[i].UntilRFC3339)
|
||||
s.Entries[i].Reason = strings.TrimSpace(s.Entries[i].Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func validateOpsDistributedLockSettings(s OpsDistributedLockSettings) error {
|
||||
if strings.TrimSpace(s.Key) == "" {
|
||||
return errors.New("distributed_lock.key is required")
|
||||
}
|
||||
if s.TTLSeconds <= 0 || s.TTLSeconds > int((24*time.Hour).Seconds()) {
|
||||
return errors.New("distributed_lock.ttl_seconds must be between 1 and 86400")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateOpsAlertSilencingSettings(s OpsAlertSilencingSettings) error {
|
||||
parse := func(raw string) error {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, raw); err != nil {
|
||||
return errors.New("silencing time must be RFC3339")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := parse(s.GlobalUntilRFC3339); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range s.Entries {
|
||||
if strings.TrimSpace(entry.UntilRFC3339) == "" {
|
||||
return errors.New("silencing.entries.until_rfc3339 is required")
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, entry.UntilRFC3339); err != nil {
|
||||
return errors.New("silencing.entries.until_rfc3339 must be RFC3339")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetOpsAlertRuntimeSettings(ctx context.Context) (*OpsAlertRuntimeSettings, error) {
|
||||
defaultCfg := defaultOpsAlertRuntimeSettings()
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsAlertRuntimeSettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
if b, mErr := json.Marshal(defaultCfg); mErr == nil {
|
||||
_ = s.settingRepo.Set(ctx, SettingKeyOpsAlertRuntimeSettings, string(b))
|
||||
}
|
||||
return defaultCfg, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsAlertRuntimeSettings{}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
if cfg.EvaluationIntervalSeconds <= 0 {
|
||||
cfg.EvaluationIntervalSeconds = defaultCfg.EvaluationIntervalSeconds
|
||||
}
|
||||
normalizeOpsDistributedLockSettings(&cfg.DistributedLock, opsAlertEvaluatorLeaderLockKeyDefault, defaultCfg.DistributedLock.TTLSeconds)
|
||||
normalizeOpsAlertSilencingSettings(&cfg.Silencing)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateOpsAlertRuntimeSettings(ctx context.Context, cfg *OpsAlertRuntimeSettings) (*OpsAlertRuntimeSettings, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("invalid config")
|
||||
}
|
||||
|
||||
if cfg.EvaluationIntervalSeconds < 1 || cfg.EvaluationIntervalSeconds > int((24*time.Hour).Seconds()) {
|
||||
return nil, errors.New("evaluation_interval_seconds must be between 1 and 86400")
|
||||
}
|
||||
if cfg.DistributedLock.Enabled {
|
||||
if err := validateOpsDistributedLockSettings(cfg.DistributedLock); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if cfg.Silencing.Enabled {
|
||||
if err := validateOpsAlertSilencingSettings(cfg.Silencing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
defaultCfg := defaultOpsAlertRuntimeSettings()
|
||||
normalizeOpsDistributedLockSettings(&cfg.DistributedLock, opsAlertEvaluatorLeaderLockKeyDefault, defaultCfg.DistributedLock.TTLSeconds)
|
||||
normalizeOpsAlertSilencingSettings(&cfg.Silencing)
|
||||
|
||||
raw, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsAlertRuntimeSettings, string(raw)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return a fresh copy (avoid callers holding pointers into internal slices that may be mutated).
|
||||
updated := &OpsAlertRuntimeSettings{}
|
||||
_ = json.Unmarshal(raw, updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Advanced settings
|
||||
// =========================
|
||||
|
||||
func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
return &OpsAdvancedSettings{
|
||||
DataRetention: OpsDataRetentionSettings{
|
||||
CleanupEnabled: false,
|
||||
CleanupSchedule: "0 2 * * *",
|
||||
ErrorLogRetentionDays: 30,
|
||||
MinuteMetricsRetentionDays: 30,
|
||||
HourlyMetricsRetentionDays: 30,
|
||||
},
|
||||
Aggregation: OpsAggregationSettings{
|
||||
AggregationEnabled: false,
|
||||
},
|
||||
IgnoreCountTokensErrors: false,
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.DataRetention.CleanupSchedule = strings.TrimSpace(cfg.DataRetention.CleanupSchedule)
|
||||
if cfg.DataRetention.CleanupSchedule == "" {
|
||||
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
|
||||
}
|
||||
if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
|
||||
cfg.DataRetention.ErrorLogRetentionDays = 30
|
||||
}
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
|
||||
cfg.DataRetention.MinuteMetricsRetentionDays = 30
|
||||
}
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
|
||||
cfg.DataRetention.HourlyMetricsRetentionDays = 30
|
||||
}
|
||||
// Normalize auto refresh interval (default 30 seconds)
|
||||
if cfg.AutoRefreshIntervalSec <= 0 {
|
||||
cfg.AutoRefreshIntervalSec = 30
|
||||
}
|
||||
}
|
||||
|
||||
func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
|
||||
if cfg == nil {
|
||||
return errors.New("invalid config")
|
||||
}
|
||||
if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
|
||||
return errors.New("error_log_retention_days must be between 1 and 365")
|
||||
}
|
||||
if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
|
||||
return errors.New("minute_metrics_retention_days must be between 1 and 365")
|
||||
}
|
||||
if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
|
||||
return errors.New("hourly_metrics_retention_days must be between 1 and 365")
|
||||
}
|
||||
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
|
||||
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetOpsAdvancedSettings(ctx context.Context) (*OpsAdvancedSettings, error) {
|
||||
defaultCfg := defaultOpsAdvancedSettings()
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsAdvancedSettings)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
if b, mErr := json.Marshal(defaultCfg); mErr == nil {
|
||||
_ = s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(b))
|
||||
}
|
||||
return defaultCfg, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsAdvancedSettings{}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
normalizeOpsAdvancedSettings(cfg)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateOpsAdvancedSettings(ctx context.Context, cfg *OpsAdvancedSettings) (*OpsAdvancedSettings, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("invalid config")
|
||||
}
|
||||
|
||||
if err := validateOpsAdvancedSettings(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizeOpsAdvancedSettings(cfg)
|
||||
raw, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsAdvancedSettings, string(raw)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated := &OpsAdvancedSettings{}
|
||||
_ = json.Unmarshal(raw, updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// =========================
|
||||
// Metric thresholds
|
||||
// =========================
|
||||
|
||||
const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
|
||||
|
||||
func defaultOpsMetricThresholds() *OpsMetricThresholds {
|
||||
slaMin := 99.5
|
||||
ttftMax := 500.0
|
||||
reqErrMax := 5.0
|
||||
upstreamErrMax := 5.0
|
||||
return &OpsMetricThresholds{
|
||||
SLAPercentMin: &slaMin,
|
||||
TTFTp99MsMax: &ttftMax,
|
||||
RequestErrorRatePercentMax: &reqErrMax,
|
||||
UpstreamErrorRatePercentMax: &upstreamErrMax,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsService) GetMetricThresholds(ctx context.Context) (*OpsMetricThresholds, error) {
|
||||
defaultCfg := defaultOpsMetricThresholds()
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMetricThresholds)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
if b, mErr := json.Marshal(defaultCfg); mErr == nil {
|
||||
_ = s.settingRepo.Set(ctx, SettingKeyOpsMetricThresholds, string(b))
|
||||
}
|
||||
return defaultCfg, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := &OpsMetricThresholds{}
|
||||
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
|
||||
return defaultCfg, nil
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricThresholds) (*OpsMetricThresholds, error) {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return nil, errors.New("setting repository not initialized")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, errors.New("invalid config")
|
||||
}
|
||||
|
||||
// Validate thresholds
|
||||
if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
|
||||
return nil, errors.New("sla_percent_min must be between 0 and 100")
|
||||
}
|
||||
if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
|
||||
return nil, errors.New("ttft_p99_ms_max must be >= 0")
|
||||
}
|
||||
if cfg.RequestErrorRatePercentMax != nil && (*cfg.RequestErrorRatePercentMax < 0 || *cfg.RequestErrorRatePercentMax > 100) {
|
||||
return nil, errors.New("request_error_rate_percent_max must be between 0 and 100")
|
||||
}
|
||||
if cfg.UpstreamErrorRatePercentMax != nil && (*cfg.UpstreamErrorRatePercentMax < 0 || *cfg.UpstreamErrorRatePercentMax > 100) {
|
||||
return nil, errors.New("upstream_error_rate_percent_max must be between 0 and 100")
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyOpsMetricThresholds, string(raw)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated := &OpsMetricThresholds{}
|
||||
_ = json.Unmarshal(raw, updated)
|
||||
return updated, nil
|
||||
}
|
||||
100
backend/internal/service/ops_settings_models.go
Normal file
100
backend/internal/service/ops_settings_models.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package service
|
||||
|
||||
// Ops settings models stored in DB `settings` table (JSON blobs).
|
||||
|
||||
type OpsEmailNotificationConfig struct {
|
||||
Alert OpsEmailAlertConfig `json:"alert"`
|
||||
Report OpsEmailReportConfig `json:"report"`
|
||||
}
|
||||
|
||||
type OpsEmailAlertConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Recipients []string `json:"recipients"`
|
||||
MinSeverity string `json:"min_severity"`
|
||||
RateLimitPerHour int `json:"rate_limit_per_hour"`
|
||||
BatchingWindowSeconds int `json:"batching_window_seconds"`
|
||||
IncludeResolvedAlerts bool `json:"include_resolved_alerts"`
|
||||
}
|
||||
|
||||
type OpsEmailReportConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Recipients []string `json:"recipients"`
|
||||
DailySummaryEnabled bool `json:"daily_summary_enabled"`
|
||||
DailySummarySchedule string `json:"daily_summary_schedule"`
|
||||
WeeklySummaryEnabled bool `json:"weekly_summary_enabled"`
|
||||
WeeklySummarySchedule string `json:"weekly_summary_schedule"`
|
||||
ErrorDigestEnabled bool `json:"error_digest_enabled"`
|
||||
ErrorDigestSchedule string `json:"error_digest_schedule"`
|
||||
ErrorDigestMinCount int `json:"error_digest_min_count"`
|
||||
AccountHealthEnabled bool `json:"account_health_enabled"`
|
||||
AccountHealthSchedule string `json:"account_health_schedule"`
|
||||
AccountHealthErrorRateThreshold float64 `json:"account_health_error_rate_threshold"`
|
||||
}
|
||||
|
||||
// OpsEmailNotificationConfigUpdateRequest allows partial updates, while the
|
||||
// frontend can still send the full config shape.
|
||||
type OpsEmailNotificationConfigUpdateRequest struct {
|
||||
Alert *OpsEmailAlertConfig `json:"alert"`
|
||||
Report *OpsEmailReportConfig `json:"report"`
|
||||
}
|
||||
|
||||
type OpsDistributedLockSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Key string `json:"key"`
|
||||
TTLSeconds int `json:"ttl_seconds"`
|
||||
}
|
||||
|
||||
type OpsAlertSilenceEntry struct {
|
||||
RuleID *int64 `json:"rule_id,omitempty"`
|
||||
Severities []string `json:"severities,omitempty"`
|
||||
|
||||
UntilRFC3339 string `json:"until_rfc3339"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type OpsAlertSilencingSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
GlobalUntilRFC3339 string `json:"global_until_rfc3339"`
|
||||
GlobalReason string `json:"global_reason"`
|
||||
|
||||
Entries []OpsAlertSilenceEntry `json:"entries,omitempty"`
|
||||
}
|
||||
|
||||
type OpsMetricThresholds struct {
|
||||
SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红
|
||||
TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红
|
||||
RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红
|
||||
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
|
||||
}
|
||||
|
||||
type OpsAlertRuntimeSettings struct {
|
||||
EvaluationIntervalSeconds int `json:"evaluation_interval_seconds"`
|
||||
|
||||
DistributedLock OpsDistributedLockSettings `json:"distributed_lock"`
|
||||
Silencing OpsAlertSilencingSettings `json:"silencing"`
|
||||
Thresholds OpsMetricThresholds `json:"thresholds"` // 指标阈值配置
|
||||
}
|
||||
|
||||
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
||||
type OpsAdvancedSettings struct {
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
type OpsDataRetentionSettings struct {
|
||||
CleanupEnabled bool `json:"cleanup_enabled"`
|
||||
CleanupSchedule string `json:"cleanup_schedule"`
|
||||
ErrorLogRetentionDays int `json:"error_log_retention_days"`
|
||||
MinuteMetricsRetentionDays int `json:"minute_metrics_retention_days"`
|
||||
HourlyMetricsRetentionDays int `json:"hourly_metrics_retention_days"`
|
||||
}
|
||||
|
||||
type OpsAggregationSettings struct {
|
||||
AggregationEnabled bool `json:"aggregation_enabled"`
|
||||
}
|
||||
65
backend/internal/service/ops_trend_models.go
Normal file
65
backend/internal/service/ops_trend_models.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsThroughputTrendPoint struct {
|
||||
BucketStart time.Time `json:"bucket_start"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
QPS float64 `json:"qps"`
|
||||
TPS float64 `json:"tps"`
|
||||
}
|
||||
|
||||
type OpsThroughputPlatformBreakdownItem struct {
|
||||
Platform string `json:"platform"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
}
|
||||
|
||||
type OpsThroughputGroupBreakdownItem struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
}
|
||||
|
||||
type OpsThroughputTrendResponse struct {
|
||||
Bucket string `json:"bucket"`
|
||||
|
||||
Points []*OpsThroughputTrendPoint `json:"points"`
|
||||
|
||||
// Optional drilldown helpers:
|
||||
// - When no platform/group is selected: returns totals by platform.
|
||||
// - When platform is selected but group is not: returns top groups in that platform.
|
||||
ByPlatform []*OpsThroughputPlatformBreakdownItem `json:"by_platform,omitempty"`
|
||||
TopGroups []*OpsThroughputGroupBreakdownItem `json:"top_groups,omitempty"`
|
||||
}
|
||||
|
||||
type OpsErrorTrendPoint struct {
|
||||
BucketStart time.Time `json:"bucket_start"`
|
||||
|
||||
ErrorCountTotal int64 `json:"error_count_total"`
|
||||
BusinessLimitedCount int64 `json:"business_limited_count"`
|
||||
ErrorCountSLA int64 `json:"error_count_sla"`
|
||||
|
||||
UpstreamErrorCountExcl429529 int64 `json:"upstream_error_count_excl_429_529"`
|
||||
Upstream429Count int64 `json:"upstream_429_count"`
|
||||
Upstream529Count int64 `json:"upstream_529_count"`
|
||||
}
|
||||
|
||||
type OpsErrorTrendResponse struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Points []*OpsErrorTrendPoint `json:"points"`
|
||||
}
|
||||
|
||||
type OpsErrorDistributionItem struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Total int64 `json:"total"`
|
||||
SLA int64 `json:"sla"`
|
||||
BusinessLimited int64 `json:"business_limited"`
|
||||
}
|
||||
|
||||
type OpsErrorDistributionResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
Items []*OpsErrorDistributionItem `json:"items"`
|
||||
}
|
||||
26
backend/internal/service/ops_trends.go
Normal file
26
backend/internal/service/ops_trends.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds)
|
||||
}
|
||||
131
backend/internal/service/ops_upstream_context.go
Normal file
131
backend/internal/service/ops_upstream_context.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Gin context keys used by Ops error logger for capturing upstream error details.
|
||||
// These keys are set by gateway services and consumed by handler/ops_error_logger.go.
|
||||
const (
|
||||
OpsUpstreamStatusCodeKey = "ops_upstream_status_code"
|
||||
OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
|
||||
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
|
||||
OpsUpstreamErrorsKey = "ops_upstream_errors"
|
||||
|
||||
// Best-effort capture of the current upstream request body so ops can
|
||||
// retry the specific upstream attempt (not just the client request).
|
||||
// This value is sanitized+trimmed before being persisted.
|
||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||
)
|
||||
|
||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if upstreamStatusCode > 0 {
|
||||
c.Set(OpsUpstreamStatusCodeKey, upstreamStatusCode)
|
||||
}
|
||||
if msg := strings.TrimSpace(upstreamMessage); msg != "" {
|
||||
c.Set(OpsUpstreamErrorMessageKey, msg)
|
||||
}
|
||||
if detail := strings.TrimSpace(upstreamDetail); detail != "" {
|
||||
c.Set(OpsUpstreamErrorDetailKey, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// OpsUpstreamErrorEvent describes one upstream error attempt during a single gateway request.
|
||||
// It is stored in ops_error_logs.upstream_errors as a JSON array.
|
||||
type OpsUpstreamErrorEvent struct {
|
||||
AtUnixMs int64 `json:"at_unix_ms,omitempty"`
|
||||
|
||||
// Context
|
||||
Platform string `json:"platform,omitempty"`
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
AccountName string `json:"account_name,omitempty"`
|
||||
|
||||
// Outcome
|
||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||
|
||||
// Best-effort upstream request capture (sanitized+trimmed).
|
||||
// Required for retrying a specific upstream attempt.
|
||||
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
|
||||
|
||||
// Best-effort upstream response capture (sanitized+trimmed).
|
||||
UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
|
||||
|
||||
// Kind: http_error | request_error | retry_exhausted | failover
|
||||
Kind string `json:"kind,omitempty"`
|
||||
|
||||
Message string `json:"message,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if ev.AtUnixMs <= 0 {
|
||||
ev.AtUnixMs = time.Now().UnixMilli()
|
||||
}
|
||||
ev.Platform = strings.TrimSpace(ev.Platform)
|
||||
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
|
||||
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
|
||||
ev.Kind = strings.TrimSpace(ev.Kind)
|
||||
ev.Message = strings.TrimSpace(ev.Message)
|
||||
ev.Detail = strings.TrimSpace(ev.Detail)
|
||||
if ev.Message != "" {
|
||||
ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
|
||||
}
|
||||
|
||||
// If the caller didn't explicitly pass upstream request body but the gateway
|
||||
// stored it on the context, attach it so ops can retry this specific attempt.
|
||||
if ev.UpstreamRequestBody == "" {
|
||||
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var existing []*OpsUpstreamErrorEvent
|
||||
if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
|
||||
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
|
||||
existing = arr
|
||||
}
|
||||
}
|
||||
|
||||
evCopy := ev
|
||||
existing = append(existing, &evCopy)
|
||||
c.Set(OpsUpstreamErrorsKey, existing)
|
||||
}
|
||||
|
||||
func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
|
||||
if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Ensure we always store a valid JSON value.
|
||||
raw, err := json.Marshal(events)
|
||||
if err != nil || len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
s := string(raw)
|
||||
return &s
|
||||
}
|
||||
|
||||
func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return []*OpsUpstreamErrorEvent{}, nil
|
||||
}
|
||||
var out []*OpsUpstreamErrorEvent
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
24
backend/internal/service/ops_window_stats.go
Normal file
24
backend/internal/service/ops_window_stats.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// GetWindowStats returns lightweight request/token counts for the provided window.
|
||||
// It is intended for realtime sampling (e.g. WebSocket QPS push) without computing percentiles/peaks.
|
||||
func (s *OpsService) GetWindowStats(ctx context.Context, startTime, endTime time.Time) (*OpsWindowStats, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
filter := &OpsDashboardFilter{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
}
|
||||
return s.opsRepo.GetWindowStats(ctx, filter)
|
||||
}
|
||||
73
backend/internal/service/promo_code.go
Normal file
73
backend/internal/service/promo_code.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联
|
||||
UsageRecords []PromoCodeUsage
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64
|
||||
PromoCodeID int64
|
||||
UserID int64
|
||||
BonusAmount float64
|
||||
UsedAt time.Time
|
||||
|
||||
// 关联
|
||||
PromoCode *PromoCode
|
||||
User *User
|
||||
}
|
||||
|
||||
// CanUse 检查优惠码是否可用
|
||||
func (p *PromoCode) CanUse() bool {
|
||||
if p.Status != PromoCodeStatusActive {
|
||||
return false
|
||||
}
|
||||
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsExpired 检查是否已过期
|
||||
func (p *PromoCode) IsExpired() bool {
|
||||
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
|
||||
}
|
||||
|
||||
// CreatePromoCodeInput 创建优惠码输入
|
||||
type CreatePromoCodeInput struct {
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
}
|
||||
|
||||
// UpdatePromoCodeInput 更新优惠码输入
|
||||
type UpdatePromoCodeInput struct {
|
||||
Code *string
|
||||
BonusAmount *float64
|
||||
MaxUses *int
|
||||
Status *string
|
||||
ExpiresAt *time.Time
|
||||
Notes *string
|
||||
}
|
||||
30
backend/internal/service/promo_code_repository.go
Normal file
30
backend/internal/service/promo_code_repository.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// PromoCodeRepository 优惠码仓储接口
|
||||
type PromoCodeRepository interface {
|
||||
// 基础 CRUD
|
||||
Create(ctx context.Context, code *PromoCode) error
|
||||
GetByID(ctx context.Context, id int64) (*PromoCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*PromoCode, error)
|
||||
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
|
||||
Update(ctx context.Context, code *PromoCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// 列表查询
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
|
||||
// 使用记录
|
||||
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
|
||||
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
|
||||
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
|
||||
|
||||
// 计数操作
|
||||
IncrementUsedCount(ctx context.Context, id int64) error
|
||||
}
|
||||
268
backend/internal/service/promo_service.go
Normal file
268
backend/internal/service/promo_service.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
||||
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
||||
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
||||
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
||||
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
||||
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
||||
)
|
||||
|
||||
// PromoService 优惠码服务
|
||||
type PromoService struct {
|
||||
promoRepo PromoCodeRepository
|
||||
userRepo UserRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
// NewPromoService 创建优惠码服务实例
|
||||
func NewPromoService(
|
||||
promoRepo PromoCodeRepository,
|
||||
userRepo UserRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
) *PromoService {
|
||||
return &PromoService{
|
||||
promoRepo: promoRepo,
|
||||
userRepo: userRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(注册前调用)
|
||||
// 返回 nil, nil 表示空码(不报错)
|
||||
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil, nil // 空码不报错,直接返回
|
||||
}
|
||||
|
||||
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
// 保留原始错误类型,不要统一映射为 NotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// validatePromoCodeStatus 验证优惠码状态
|
||||
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
||||
if !promoCode.CanUse() {
|
||||
if promoCode.IsExpired() {
|
||||
return ErrPromoCodeExpired
|
||||
}
|
||||
if promoCode.Status == PromoCodeStatusDisabled {
|
||||
return ErrPromoCodeDisabled
|
||||
}
|
||||
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
||||
return ErrPromoCodeMaxUsed
|
||||
}
|
||||
return ErrPromoCodeInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
||||
// 使用事务和行锁确保并发安全
|
||||
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
||||
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中验证优惠码状态
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中检查用户是否已使用过此优惠码
|
||||
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing usage: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return ErrPromoCodeAlreadyUsed
|
||||
}
|
||||
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
||||
return fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用记录
|
||||
usage := &PromoCodeUsage{
|
||||
PromoCodeID: promoCode.ID,
|
||||
UserID: userID,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
UsedAt: time.Now(),
|
||||
}
|
||||
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
||||
return fmt.Errorf("create usage record: %w", err)
|
||||
}
|
||||
|
||||
// 增加使用次数
|
||||
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
||||
return fmt.Errorf("increment used count: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
|
||||
if bonusAmount == 0 || s.authCacheInvalidator == nil {
|
||||
return
|
||||
}
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机优惠码
|
||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
||||
}
|
||||
|
||||
// Create 创建优惠码
|
||||
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
||||
code := strings.TrimSpace(input.Code)
|
||||
if code == "" {
|
||||
// 自动生成
|
||||
var err error
|
||||
code, err = s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promoCode := &PromoCode{
|
||||
Code: strings.ToUpper(code),
|
||||
BonusAmount: input.BonusAmount,
|
||||
MaxUses: input.MaxUses,
|
||||
UsedCount: 0,
|
||||
Status: PromoCodeStatusActive,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Notes: input.Notes,
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("create promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠码
|
||||
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
code, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// Update 更新优惠码
|
||||
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
||||
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Code != nil {
|
||||
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
||||
}
|
||||
if input.BonusAmount != nil {
|
||||
promoCode.BonusAmount = *input.BonusAmount
|
||||
}
|
||||
if input.MaxUses != nil {
|
||||
promoCode.MaxUses = *input.MaxUses
|
||||
}
|
||||
if input.Status != nil {
|
||||
promoCode.Status = *input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
promoCode.ExpiresAt = input.ExpiresAt
|
||||
}
|
||||
if input.Notes != nil {
|
||||
promoCode.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("update promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// Delete 删除优惠码
|
||||
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete promo code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取优惠码列表
|
||||
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// ListUsages 获取使用记录
|
||||
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
||||
}
|
||||
275
backend/internal/service/prompts/codex_cli_instructions.md
Normal file
275
backend/internal/service/prompts/codex_cli_instructions.md
Normal file
@@ -0,0 +1,275 @@
|
||||
You are a coding agent running in the Codex CLI, a terminal-based coding assistant. Codex CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful.
|
||||
|
||||
Your capabilities:
|
||||
|
||||
- Receive user prompts and other context provided by the harness, such as files in the workspace.
|
||||
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
|
||||
- Emit function calls to run terminal commands and apply patches. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
|
||||
|
||||
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
|
||||
|
||||
# How you work
|
||||
|
||||
## Personality
|
||||
|
||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
||||
|
||||
# AGENTS.md spec
|
||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
||||
- Instructions in AGENTS.md files:
|
||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
||||
|
||||
## Responsiveness
|
||||
|
||||
### Preamble messages
|
||||
|
||||
Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples:
|
||||
|
||||
- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
|
||||
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates).
|
||||
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions.
|
||||
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
|
||||
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action.
|
||||
|
||||
**Examples:**
|
||||
|
||||
- “I’ve explored the repo; now checking the API route definitions.”
|
||||
- “Next, I’ll patch the config and update the related tests.”
|
||||
- “I’m about to scaffold the CLI commands and helper functions.”
|
||||
- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.”
|
||||
- “Config’s looking tidy. Next up is patching helpers to keep things in sync.”
|
||||
- “Finished poking at the DB gateway. I will now chase down error handling.”
|
||||
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
|
||||
- “Spotted a clever caching util; now hunting where it gets used.”
|
||||
|
||||
## Planning
|
||||
|
||||
You have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
|
||||
|
||||
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
|
||||
|
||||
Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
|
||||
|
||||
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. It may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
||||
|
||||
Use a plan when:
|
||||
|
||||
- The task is non-trivial and will require multiple actions over a long time horizon.
|
||||
- There are logical phases or dependencies where sequencing matters.
|
||||
- The work has ambiguity that benefits from outlining high-level goals.
|
||||
- You want intermediate checkpoints for feedback and validation.
|
||||
- When the user asked you to do more than one thing in a single prompt
|
||||
- The user has asked you to use the plan tool (aka "TODOs")
|
||||
- You generate additional steps while working, and plan to do them before yielding to the user
|
||||
|
||||
### Examples
|
||||
|
||||
**High-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Add CLI entry with file args
|
||||
2. Parse Markdown via CommonMark library
|
||||
3. Apply semantic HTML template
|
||||
4. Handle code blocks, images, links
|
||||
5. Add error handling for invalid files
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Define CSS variables for colors
|
||||
2. Add toggle with localStorage state
|
||||
3. Refactor components to use variables
|
||||
4. Verify all views for readability
|
||||
5. Add smooth theme-change transition
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Set up Node.js + WebSocket server
|
||||
2. Add join/leave broadcast events
|
||||
3. Implement messaging with timestamps
|
||||
4. Add usernames + mention highlighting
|
||||
5. Persist messages in lightweight DB
|
||||
6. Add typing indicators + unread count
|
||||
|
||||
**Low-quality plans**
|
||||
|
||||
Example 1:
|
||||
|
||||
1. Create CLI tool
|
||||
2. Add Markdown parser
|
||||
3. Convert to HTML
|
||||
|
||||
Example 2:
|
||||
|
||||
1. Add dark mode toggle
|
||||
2. Save preference
|
||||
3. Make styles look good
|
||||
|
||||
Example 3:
|
||||
|
||||
1. Create single-file HTML game
|
||||
2. Run quick sanity check
|
||||
3. Summarize usage instructions
|
||||
|
||||
If you need to write a plan, only write high quality plans, not low quality ones.
|
||||
|
||||
## Task execution
|
||||
|
||||
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
|
||||
|
||||
You MUST adhere to the following criteria when solving queries:
|
||||
|
||||
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
|
||||
- Analyzing code for vulnerabilities is allowed.
|
||||
- Showing user code and tool call details is allowed.
|
||||
- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]}
|
||||
|
||||
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
|
||||
|
||||
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
|
||||
- Avoid unneeded complexity in your solution.
|
||||
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
- Update documentation as necessary.
|
||||
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
|
||||
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
|
||||
- NEVER add copyright or license headers unless specifically requested.
|
||||
- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
|
||||
- Do not `git commit` your changes or create new git branches unless explicitly requested.
|
||||
- Do not add inline comments within code unless explicitly requested.
|
||||
- Do not use one-letter variable names unless explicitly requested.
|
||||
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
|
||||
|
||||
## Validating your work
|
||||
|
||||
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
|
||||
|
||||
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
|
||||
|
||||
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
|
||||
|
||||
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
||||
|
||||
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
|
||||
|
||||
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
|
||||
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
|
||||
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
|
||||
|
||||
## Ambition vs. precision
|
||||
|
||||
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
|
||||
|
||||
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
|
||||
|
||||
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
|
||||
|
||||
## Sharing progress updates
|
||||
|
||||
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
|
||||
|
||||
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
|
||||
|
||||
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
|
||||
|
||||
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation.
|
||||
|
||||
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
|
||||
|
||||
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
|
||||
|
||||
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
**Section Headers**
|
||||
|
||||
- Use only when they improve clarity — they are not mandatory for every answer.
|
||||
- Choose descriptive names that fit the content
|
||||
- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
|
||||
- Leave no blank line before the first bullet under a header.
|
||||
- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer.
|
||||
|
||||
**Bullets**
|
||||
|
||||
- Use `-` followed by a space for every bullet.
|
||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
||||
- Group into short lists (4–6 bullets) ordered by importance.
|
||||
- Use consistent keyword phrasing and formatting across sections.
|
||||
|
||||
**Monospace**
|
||||
|
||||
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
|
||||
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
|
||||
- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).
|
||||
|
||||
**File References**
|
||||
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
|
||||
**Structure**
|
||||
|
||||
- Place related bullets together; don’t mix unrelated concepts in the same section.
|
||||
- Order sections from general → specific → supporting info.
|
||||
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
|
||||
- Match structure to complexity:
|
||||
- Multi-part or detailed results → use clear headers and grouped bullets.
|
||||
- Simple results → minimal headers, possibly just a short list or paragraph.
|
||||
|
||||
**Tone**
|
||||
|
||||
- Keep the voice collaborative and natural, like a coding partner handing off work.
|
||||
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
|
||||
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
|
||||
- Keep descriptions self-contained; don’t refer to “above” or “below”.
|
||||
- Use parallel structure in lists for consistency.
|
||||
|
||||
**Don’t**
|
||||
|
||||
- Don’t use literal words “bold” or “monospace” in the content.
|
||||
- Don’t nest bullets or create deep hierarchies.
|
||||
- Don’t output ANSI escape codes directly — the CLI renderer applies them.
|
||||
- Don’t cram unrelated keywords into a single bullet; split for clarity.
|
||||
- Don’t let keyword lists run long — wrap or reformat for scanability.
|
||||
|
||||
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
|
||||
|
||||
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
|
||||
|
||||
# Tool Guidelines
|
||||
|
||||
## Shell commands
|
||||
|
||||
When using the shell, you must adhere to the following guidelines:
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
- Do not use python scripts to attempt to output larger chunks of a file.
|
||||
|
||||
## `update_plan`
|
||||
|
||||
A tool named `update_plan` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task.
|
||||
|
||||
To create a new plan, call `update_plan` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
|
||||
|
||||
When steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call.
|
||||
|
||||
If all steps are complete, ensure you call `update_plan` to mark all steps as `completed`.
|
||||
122
backend/internal/service/prompts/codex_opencode_bridge.txt
Normal file
122
backend/internal/service/prompts/codex_opencode_bridge.txt
Normal file
@@ -0,0 +1,122 @@
|
||||
# Codex Running in OpenCode
|
||||
|
||||
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
|
||||
|
||||
## CRITICAL: Tool Replacements
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||
- NEVER use: apply_patch, applyPatch
|
||||
- ALWAYS use: edit tool for ALL file modifications
|
||||
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||
</critical_rule>
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||
- NEVER use: update_plan, updatePlan, read_plan, readPlan
|
||||
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
|
||||
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||
</critical_rule>
|
||||
|
||||
## Available OpenCode Tools
|
||||
|
||||
**File Operations:**
|
||||
- `write` - Create new files
|
||||
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
|
||||
- `edit` - Modify existing files (REPLACES apply_patch)
|
||||
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
|
||||
- `read` - Read file contents
|
||||
|
||||
**Search/Discovery:**
|
||||
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
|
||||
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
|
||||
- `list` - List directories (requires absolute paths)
|
||||
|
||||
**Execution:**
|
||||
- `bash` - Run shell commands
|
||||
- No workdir parameter; do not include it in tool calls.
|
||||
- Always include a short description for the command.
|
||||
- Do not use cd; use absolute paths in commands.
|
||||
- Quote paths containing spaces with double quotes.
|
||||
- Chain multiple commands with ';' or '&&'; avoid newlines.
|
||||
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
|
||||
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
|
||||
- For deletions (rm), verify by listing parent dir with `list`.
|
||||
|
||||
**Network:**
|
||||
- `webfetch` - Fetch web content
|
||||
- Use fully-formed URLs (http/https; http auto-upgrades to https).
|
||||
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
|
||||
- Read-only; short cache window.
|
||||
|
||||
**Task Management:**
|
||||
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
|
||||
- `todoread` - Read current plan
|
||||
|
||||
## Substitution Rules
|
||||
|
||||
Base instruction says: You MUST use instead:
|
||||
apply_patch → edit
|
||||
update_plan → todowrite
|
||||
read_plan → todoread
|
||||
|
||||
**Path Usage:** Use per-tool conventions to avoid conflicts:
|
||||
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
|
||||
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
|
||||
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
|
||||
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
Before file/plan modifications:
|
||||
1. Am I using "edit" NOT "apply_patch"?
|
||||
2. Am I using "todowrite" NOT "update_plan"?
|
||||
3. Is this tool in the approved list above?
|
||||
4. Am I following each tool's path requirements?
|
||||
|
||||
If ANY answer is NO → STOP and correct before proceeding.
|
||||
|
||||
## OpenCode Working Style
|
||||
|
||||
**Communication:**
|
||||
- Send brief preambles (8-12 words) before tool calls, building on prior context
|
||||
- Provide progress updates during longer tasks
|
||||
|
||||
**Execution:**
|
||||
- Keep working autonomously until query is fully resolved before yielding
|
||||
- Don't return to user with partial solutions
|
||||
|
||||
**Code Approach:**
|
||||
- New projects: Be ambitious and creative
|
||||
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
|
||||
|
||||
**Testing:**
|
||||
- If tests exist: Start specific to your changes, then broader validation
|
||||
|
||||
## Advanced Tools
|
||||
|
||||
**Task Tool (Sub-Agents):**
|
||||
- Use the Task tool (functions.task) to launch sub-agents
|
||||
- Check the Task tool description for current agent types and their capabilities
|
||||
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
|
||||
- The agent list is dynamically generated - refer to tool schema for available agents
|
||||
|
||||
**Parallelization:**
|
||||
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
|
||||
- Reserve sequential calls for ordered or data-dependent steps.
|
||||
|
||||
**MCP Tools:**
|
||||
- Model Context Protocol servers provide additional capabilities
|
||||
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
|
||||
- Check your available tools for MCP integrations
|
||||
- Use when the tool's functionality matches your task needs
|
||||
|
||||
## What Remains from Codex
|
||||
|
||||
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
|
||||
|
||||
## Approvals & Safety
|
||||
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
|
||||
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
|
||||
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
|
||||
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).
|
||||
63
backend/internal/service/prompts/tool_remap_message.txt
Normal file
63
backend/internal/service/prompts/tool_remap_message.txt
Normal file
@@ -0,0 +1,63 @@
|
||||
<user_instructions priority="0">
|
||||
<environment_override priority="0">
|
||||
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
|
||||
</environment_override>
|
||||
|
||||
<tool_replacements priority="0">
|
||||
<critical_rule priority="0">
|
||||
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||
- NEVER use: apply_patch, applyPatch
|
||||
- ALWAYS use: edit tool for ALL file modifications
|
||||
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||
</critical_rule>
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||
- NEVER use: update_plan, updatePlan
|
||||
- ALWAYS use: todowrite for ALL task/plan operations
|
||||
- Use todoread to read current plan
|
||||
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||
</critical_rule>
|
||||
</tool_replacements>
|
||||
|
||||
<available_tools priority="0">
|
||||
File Operations:
|
||||
• write - Create new files
|
||||
• edit - Modify existing files (REPLACES apply_patch)
|
||||
• patch - Apply diff patches
|
||||
• read - Read file contents
|
||||
|
||||
Search/Discovery:
|
||||
• grep - Search file contents
|
||||
• glob - Find files by pattern
|
||||
• list - List directories (use relative paths)
|
||||
|
||||
Execution:
|
||||
• bash - Run shell commands
|
||||
|
||||
Network:
|
||||
• webfetch - Fetch web content
|
||||
|
||||
Task Management:
|
||||
• todowrite - Manage tasks/plans (REPLACES update_plan)
|
||||
• todoread - Read current plan
|
||||
</available_tools>
|
||||
|
||||
<substitution_rules priority="0">
|
||||
Base instruction says: You MUST use instead:
|
||||
apply_patch → edit
|
||||
update_plan → todowrite
|
||||
read_plan → todoread
|
||||
absolute paths → relative paths
|
||||
</substitution_rules>
|
||||
|
||||
<verification_checklist priority="0">
|
||||
Before file/plan modifications:
|
||||
1. Am I using "edit" NOT "apply_patch"?
|
||||
2. Am I using "todowrite" NOT "update_plan"?
|
||||
3. Is this tool in the approved list above?
|
||||
4. Am I using relative paths?
|
||||
|
||||
If ANY answer is NO → STOP and correct before proceeding.
|
||||
</verification_checklist>
|
||||
</user_instructions>
|
||||
@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64
|
||||
AccountCount int64
|
||||
LatencyMs *int64
|
||||
LatencyStatus string
|
||||
LatencyMessage string
|
||||
IPAddress string
|
||||
Country string
|
||||
CountryCode string
|
||||
Region string
|
||||
City string
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
ID int64
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Notes *string
|
||||
}
|
||||
|
||||
23
backend/internal/service/proxy_latency_cache.go
Normal file
23
backend/internal/service/proxy_latency_cache.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyLatencyInfo struct {
|
||||
Success bool `json:"success"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyLatencyCache interface {
|
||||
GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
|
||||
SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
|
||||
ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
@@ -26,6 +27,7 @@ type ProxyRepository interface {
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||
}
|
||||
|
||||
// CreateProxyRequest 创建代理请求
|
||||
|
||||
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -15,13 +15,16 @@ import (
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
timeoutCounterCache TimeoutCounterCache
|
||||
settingService *SettingService
|
||||
tokenCacheInvalidator TokenCacheInvalidator
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
@@ -44,43 +47,105 @@ func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogReposi
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeoutCounterCache 设置超时计数器缓存(可选依赖)
|
||||
func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
|
||||
s.timeoutCounterCache = cache
|
||||
}
|
||||
|
||||
// SetSettingService 设置系统设置服务(可选依赖)
|
||||
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
||||
s.settingService = settingService
|
||||
}
|
||||
|
||||
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
|
||||
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
tempMatched := false
|
||||
if statusCode != 401 {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
upstreamMsg = truncateForLog([]byte(upstreamMsg), 512)
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
// 认证失败:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||
} else {
|
||||
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||
}
|
||||
}
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 402:
|
||||
// 支付要求:余额不足或计费问题,停止调度
|
||||
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
|
||||
msg := "Payment required (402): insufficient balance or billing issue"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Payment required (402): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
|
||||
msg := "Access forbidden (403): account may be suspended or lack permissions"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Access forbidden (403): " + upstreamMsg
|
||||
}
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers)
|
||||
s.handle429(ctx, account, headers, responseBody)
|
||||
shouldDisable = false
|
||||
case 529:
|
||||
s.handle529(ctx, account)
|
||||
shouldDisable = false
|
||||
default:
|
||||
// 其他5xx错误:记录但不停止调度
|
||||
if statusCode >= 500 {
|
||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
||||
// 自定义错误码启用时:在列表中的错误码都应该停止调度
|
||||
if customErrorCodesEnabled {
|
||||
msg := "Custom error code triggered"
|
||||
if upstreamMsg != "" {
|
||||
msg = upstreamMsg
|
||||
}
|
||||
s.handleCustomErrorCode(ctx, account, statusCode, msg)
|
||||
shouldDisable = true
|
||||
} else if statusCode >= 500 {
|
||||
// 未启用自定义错误码时:仅记录5xx错误
|
||||
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
|
||||
shouldDisable = false
|
||||
}
|
||||
shouldDisable = false
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
@@ -125,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -150,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
// NOTE:
|
||||
// - This is a local precheck to reduce upstream 429s.
|
||||
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -172,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -193,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
if used >= limit {
|
||||
resetAt := start.Add(time.Minute)
|
||||
// Do not persist "rate limited" status from local precheck. See note above.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -250,22 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 解析重置时间戳
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
} else {
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -273,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 解析Unix时间戳
|
||||
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
} else {
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
return
|
||||
}
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -293,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
|
||||
if account == nil || account.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(msg, "sonnet")
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
@@ -309,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Account %d overloaded until %v", account.ID, until)
|
||||
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||
@@ -336,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
end := start.Add(5 * time.Hour)
|
||||
windowStart = &start
|
||||
windowEnd = &end
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||
}
|
||||
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
|
||||
if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.accountRepo.ClearModelRateLimits(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
@@ -365,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
}
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
||||
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -412,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
||||
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
|
||||
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -538,3 +652,124 @@ func truncateTempUnschedMessage(body []byte, maxBytes int) string {
|
||||
}
|
||||
return strings.TrimSpace(string(body))
|
||||
}
|
||||
|
||||
// HandleStreamTimeout 处理流数据超时
|
||||
// 根据系统设置决定是否标记账户为临时不可调度或错误状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Account, model string) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取系统设置
|
||||
if s.settingService == nil {
|
||||
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if !settings.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if settings.Action == StreamTimeoutActionNone {
|
||||
return false
|
||||
}
|
||||
|
||||
// 增加超时计数
|
||||
var count int64 = 1
|
||||
if s.timeoutCounterCache != nil {
|
||||
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
|
||||
if err != nil {
|
||||
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
|
||||
// 继续处理,使用 count=1
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
|
||||
|
||||
// 检查是否达到阈值
|
||||
if count < int64(settings.ThresholdCount) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 达到阈值,执行相应操作
|
||||
switch settings.Action {
|
||||
case StreamTimeoutActionTempUnsched:
|
||||
return s.triggerStreamTimeoutTempUnsched(ctx, account, settings, model)
|
||||
case StreamTimeoutActionError:
|
||||
return s.triggerStreamTimeoutError(ctx, account, model)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// triggerStreamTimeoutTempUnsched 触发流超时临时不可调度
|
||||
func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, account *Account, settings *StreamTimeoutSettings, model string) bool {
|
||||
now := time.Now()
|
||||
until := now.Add(time.Duration(settings.TempUnschedMinutes) * time.Minute)
|
||||
|
||||
state := &TempUnschedState{
|
||||
UntilUnix: until.Unix(),
|
||||
TriggeredAtUnix: now.Unix(),
|
||||
StatusCode: 0, // 超时没有状态码
|
||||
MatchedKeyword: "stream_timeout",
|
||||
RuleIndex: -1, // 表示系统级规则
|
||||
ErrorMessage: "Stream data interval timeout for model: " + model,
|
||||
}
|
||||
|
||||
reason := ""
|
||||
if raw, err := json.Marshal(state); err == nil {
|
||||
reason = string(raw)
|
||||
}
|
||||
if reason == "" {
|
||||
reason = state.ErrorMessage
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
// triggerStreamTimeoutError 触发流超时错误状态
|
||||
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
|
||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
121
backend/internal/service/ratelimit_service_401_test.go
Normal file
121
backend/internal/service/ratelimit_service_401_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type rateLimitAccountRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
tempCalls int
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.lastErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorRecorder struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
r.accounts = append(r.accounts, account)
|
||||
return r.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
}
|
||||
@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo RedeemCodeRepository
|
||||
userRepo UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
cache RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
redeemRepo RedeemCodeRepository
|
||||
userRepo UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
cache RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
@@ -84,14 +85,16 @@ func NewRedeemService(
|
||||
cache RedeemCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
redeemRepo: redeemRepo,
|
||||
userRepo: userRepo,
|
||||
subscriptionService: subscriptionService,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
|
||||
// invalidateRedeemCaches 失效兑换相关的缓存
|
||||
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
|
||||
if s.billingCacheService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch redeemCode.Type {
|
||||
case RedeemTypeBalance:
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCacheService == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
case RedeemTypeConcurrency:
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCacheService == nil {
|
||||
return
|
||||
}
|
||||
case RedeemTypeSubscription:
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCacheService == nil {
|
||||
return
|
||||
}
|
||||
if redeemCode.GroupID != nil {
|
||||
groupID := *redeemCode.GroupID
|
||||
go func() {
|
||||
|
||||
68
backend/internal/service/scheduler_cache.go
Normal file
68
backend/internal/service/scheduler_cache.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
SchedulerModeSingle = "single"
|
||||
SchedulerModeMixed = "mixed"
|
||||
SchedulerModeForced = "forced"
|
||||
)
|
||||
|
||||
type SchedulerBucket struct {
|
||||
GroupID int64
|
||||
Platform string
|
||||
Mode string
|
||||
}
|
||||
|
||||
func (b SchedulerBucket) String() string {
|
||||
return fmt.Sprintf("%d:%s:%s", b.GroupID, b.Platform, b.Mode)
|
||||
}
|
||||
|
||||
func ParseSchedulerBucket(raw string) (SchedulerBucket, bool) {
|
||||
parts := strings.Split(raw, ":")
|
||||
if len(parts) != 3 {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
groupID, err := strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
if parts[1] == "" || parts[2] == "" {
|
||||
return SchedulerBucket{}, false
|
||||
}
|
||||
return SchedulerBucket{
|
||||
GroupID: groupID,
|
||||
Platform: parts[1],
|
||||
Mode: parts[2],
|
||||
}, true
|
||||
}
|
||||
|
||||
// SchedulerCache 负责调度快照与账号快照的缓存读写。
|
||||
type SchedulerCache interface {
|
||||
// GetSnapshot 读取快照并返回命中与否(ready + active + 数据完整)。
|
||||
GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error)
|
||||
// SetSnapshot 写入快照并切换激活版本。
|
||||
SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error
|
||||
// GetAccount 获取单账号快照。
|
||||
GetAccount(ctx context.Context, accountID int64) (*Account, error)
|
||||
// SetAccount 写入单账号快照(包含不可调度状态)。
|
||||
SetAccount(ctx context.Context, account *Account) error
|
||||
// DeleteAccount 删除单账号快照。
|
||||
DeleteAccount(ctx context.Context, accountID int64) error
|
||||
// UpdateLastUsed 批量更新账号的最后使用时间。
|
||||
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
// TryLockBucket 尝试获取分桶重建锁。
|
||||
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
|
||||
// ListBuckets 返回已注册的分桶集合。
|
||||
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
|
||||
// GetOutboxWatermark 读取 outbox 水位。
|
||||
GetOutboxWatermark(ctx context.Context) (int64, error)
|
||||
// SetOutboxWatermark 保存 outbox 水位。
|
||||
SetOutboxWatermark(ctx context.Context, id int64) error
|
||||
}
|
||||
10
backend/internal/service/scheduler_events.go
Normal file
10
backend/internal/service/scheduler_events.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package service
|
||||
|
||||
const (
|
||||
SchedulerOutboxEventAccountChanged = "account_changed"
|
||||
SchedulerOutboxEventAccountGroupsChanged = "account_groups_changed"
|
||||
SchedulerOutboxEventAccountBulkChanged = "account_bulk_changed"
|
||||
SchedulerOutboxEventAccountLastUsed = "account_last_used"
|
||||
SchedulerOutboxEventGroupChanged = "group_changed"
|
||||
SchedulerOutboxEventFullRebuild = "full_rebuild"
|
||||
)
|
||||
21
backend/internal/service/scheduler_outbox.go
Normal file
21
backend/internal/service/scheduler_outbox.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SchedulerOutboxEvent struct {
|
||||
ID int64
|
||||
EventType string
|
||||
AccountID *int64
|
||||
GroupID *int64
|
||||
Payload map[string]any
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// SchedulerOutboxRepository 提供调度 outbox 的读取接口。
|
||||
type SchedulerOutboxRepository interface {
|
||||
ListAfter(ctx context.Context, afterID int64, limit int) ([]SchedulerOutboxEvent, error)
|
||||
MaxID(ctx context.Context) (int64, error)
|
||||
}
|
||||
786
backend/internal/service/scheduler_snapshot_service.go
Normal file
786
backend/internal/service/scheduler_snapshot_service.go
Normal file
@@ -0,0 +1,786 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSchedulerCacheNotReady = errors.New("scheduler cache not ready")
|
||||
ErrSchedulerFallbackLimited = errors.New("scheduler db fallback limited")
|
||||
)
|
||||
|
||||
const outboxEventTimeout = 2 * time.Minute
|
||||
|
||||
type SchedulerSnapshotService struct {
|
||||
cache SchedulerCache
|
||||
outboxRepo SchedulerOutboxRepository
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
cfg *config.Config
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
fallbackLimit *fallbackLimiter
|
||||
lagMu sync.Mutex
|
||||
lagFailures int
|
||||
}
|
||||
|
||||
func NewSchedulerSnapshotService(
|
||||
cache SchedulerCache,
|
||||
outboxRepo SchedulerOutboxRepository,
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cfg *config.Config,
|
||||
) *SchedulerSnapshotService {
|
||||
maxQPS := 0
|
||||
if cfg != nil {
|
||||
maxQPS = cfg.Gateway.Scheduling.DbFallbackMaxQPS
|
||||
}
|
||||
return &SchedulerSnapshotService{
|
||||
cache: cache,
|
||||
outboxRepo: outboxRepo,
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
fallbackLimit: newFallbackLimiter(maxQPS),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) Start() {
|
||||
if s == nil || s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runInitialRebuild()
|
||||
}()
|
||||
|
||||
interval := s.outboxPollInterval()
|
||||
if s.outboxRepo != nil && interval > 0 {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runOutboxWorker(interval)
|
||||
}()
|
||||
}
|
||||
|
||||
fullInterval := s.fullRebuildInterval()
|
||||
if fullInterval > 0 {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.runFullRebuildWorker(fullInterval)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
mode := s.resolveMode(platform, hasForcePlatform)
|
||||
bucket := s.bucketFor(groupID, platform, mode)
|
||||
|
||||
if s.cache != nil {
|
||||
cached, hit, err := s.cache.GetSnapshot(ctx, bucket)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err)
|
||||
} else if hit {
|
||||
return derefAccounts(cached), useMixed, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.guardFallback(ctx); err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
|
||||
fallbackCtx, cancel := s.withFallbackTimeout(ctx)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := s.loadAccountsFromDB(fallbackCtx, bucket, useMixed)
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if s.cache != nil {
|
||||
account, err := s.cache.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err)
|
||||
} else if account != nil {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.guardFallback(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fallbackCtx, cancel := s.withFallbackTimeout(ctx)
|
||||
defer cancel()
|
||||
return s.accountRepo.GetByID(fallbackCtx, accountID)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runInitialRebuild() {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil {
|
||||
log.Printf("[Scheduler] rebuild startup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runOutboxWorker(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.pollOutbox()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.pollOutbox()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.triggerFullRebuild("interval"); err != nil {
|
||||
log.Printf("[Scheduler] full rebuild failed: %v", err)
|
||||
}
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) pollOutbox() {
|
||||
if s.outboxRepo == nil || s.cache == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
watermark, err := s.cache.GetOutboxWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
events, err := s.outboxRepo.ListAfter(ctx, watermark, 200)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox poll failed: %v", err)
|
||||
return
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
watermarkForCheck := watermark
|
||||
for _, event := range events {
|
||||
eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout)
|
||||
err := s.handleOutboxEvent(eventCtx, event)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
lastID := events[len(events)-1].ID
|
||||
if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil {
|
||||
log.Printf("[Scheduler] outbox watermark write failed: %v", err)
|
||||
} else {
|
||||
watermarkForCheck = lastID
|
||||
}
|
||||
|
||||
s.checkOutboxLag(ctx, events[0], watermarkForCheck)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error {
|
||||
switch event.EventType {
|
||||
case SchedulerOutboxEventAccountLastUsed:
|
||||
return s.handleLastUsedEvent(ctx, event.Payload)
|
||||
case SchedulerOutboxEventAccountBulkChanged:
|
||||
return s.handleBulkAccountEvent(ctx, event.Payload)
|
||||
case SchedulerOutboxEventAccountGroupsChanged:
|
||||
return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
|
||||
case SchedulerOutboxEventAccountChanged:
|
||||
return s.handleAccountEvent(ctx, event.AccountID, event.Payload)
|
||||
case SchedulerOutboxEventGroupChanged:
|
||||
return s.handleGroupEvent(ctx, event.GroupID)
|
||||
case SchedulerOutboxEventFullRebuild:
|
||||
return s.triggerFullRebuild("outbox")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payload map[string]any) error {
|
||||
if s.cache == nil || payload == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := payload["last_used"].(map[string]any)
|
||||
if !ok || len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
updates := make(map[int64]time.Time, len(raw))
|
||||
for key, value := range raw {
|
||||
id, err := strconv.ParseInt(key, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
continue
|
||||
}
|
||||
sec, ok := toInt64(value)
|
||||
if !ok || sec <= 0 {
|
||||
continue
|
||||
}
|
||||
updates[id] = time.Unix(sec, 0)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.UpdateLastUsed(ctx, updates)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
ids := parseInt64Slice(payload["account_ids"])
|
||||
for _, id := range ids {
|
||||
if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
|
||||
if accountID == nil || *accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
if s.accountRepo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var groupIDs []int64
|
||||
if payload != nil {
|
||||
groupIDs = parseInt64Slice(payload["group_ids"])
|
||||
}
|
||||
|
||||
account, err := s.accountRepo.GetByID(ctx, *accountID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAccountNotFound) {
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeleteAccount(ctx, *accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
if err := s.cache.SetAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
groupIDs = account.GroupIDs
|
||||
}
|
||||
return s.rebuildByAccount(ctx, account, groupIDs, "account_change")
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error {
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupIDs := []int64{*groupID}
|
||||
return s.rebuildByGroupIDs(ctx, groupIDs, "group_change")
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error {
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
groupIDs = s.normalizeGroupIDs(groupIDs)
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error {
|
||||
groupIDs = s.normalizeGroupIDs(groupIDs)
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
var firstErr error
|
||||
for _, platform := range platforms {
|
||||
if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error {
|
||||
if platform == "" {
|
||||
return nil
|
||||
}
|
||||
var firstErr error
|
||||
for _, gid := range groupIDs {
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeSingle}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeForced}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeMixed}, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBuckets(ctx context.Context, buckets []SchedulerBucket, reason string) error {
|
||||
var firstErr error
|
||||
for _, bucket := range buckets {
|
||||
if err := s.rebuildBucket(ctx, bucket, reason); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket SchedulerBucket, reason string) error {
|
||||
if s.cache == nil {
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
ok, err := s.cache.TryLockBucket(ctx, bucket, 30*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil {
|
||||
log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error {
|
||||
if s.cache == nil {
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
buckets, err := s.cache.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] list buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
buckets, err = s.defaultBuckets(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] default buckets failed: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.rebuildBuckets(ctx, buckets, reason)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest SchedulerOutboxEvent, watermark int64) {
|
||||
if oldest.CreatedAt.IsZero() || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lag := time.Since(oldest.CreatedAt)
|
||||
if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 {
|
||||
log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds)
|
||||
}
|
||||
|
||||
if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds {
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures++
|
||||
failures := s.lagFailures
|
||||
s.lagMu.Unlock()
|
||||
|
||||
if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures {
|
||||
log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures)
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures = 0
|
||||
s.lagMu.Unlock()
|
||||
if err := s.triggerFullRebuild("outbox_lag"); err != nil {
|
||||
log.Printf("[Scheduler] outbox lag rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.lagMu.Lock()
|
||||
s.lagFailures = 0
|
||||
s.lagMu.Unlock()
|
||||
}
|
||||
|
||||
threshold := s.cfg.Gateway.Scheduling.OutboxBacklogRebuildRows
|
||||
if threshold <= 0 || s.outboxRepo == nil {
|
||||
return
|
||||
}
|
||||
maxID, err := s.outboxRepo.MaxID(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if maxID-watermark >= int64(threshold) {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark)
|
||||
if err := s.triggerFullRebuild("outbox_backlog"); err != nil {
|
||||
log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucket SchedulerBucket, useMixed bool) ([]Account, error) {
|
||||
if s.accountRepo == nil {
|
||||
return nil, ErrSchedulerCacheNotReady
|
||||
}
|
||||
groupID := bucket.GroupID
|
||||
if s.isRunModeSimple() {
|
||||
groupID = 0
|
||||
}
|
||||
|
||||
if useMixed {
|
||||
platforms := []string{bucket.Platform, PlatformAntigravity}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID > 0 {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
if groupID > 0 {
|
||||
return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform)
|
||||
}
|
||||
return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) SchedulerBucket {
|
||||
return SchedulerBucket{
|
||||
GroupID: s.normalizeGroupID(groupID),
|
||||
Platform: platform,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) normalizeGroupID(groupID *int64) int64 {
|
||||
if s.isRunModeSimple() {
|
||||
return 0
|
||||
}
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) normalizeGroupIDs(groupIDs []int64) []int64 {
|
||||
if s.isRunModeSimple() {
|
||||
return []int64{0}
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return []int64{0}
|
||||
}
|
||||
seen := make(map[int64]struct{}, len(groupIDs))
|
||||
out := make([]int64, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []int64{0}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) resolveMode(platform string, hasForcePlatform bool) string {
|
||||
if hasForcePlatform {
|
||||
return SchedulerModeForced
|
||||
}
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
return SchedulerModeMixed
|
||||
}
|
||||
return SchedulerModeSingle
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) guardFallback(ctx context.Context) error {
|
||||
if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackEnabled {
|
||||
if s.fallbackLimit == nil || s.fallbackLimit.Allow() {
|
||||
return nil
|
||||
}
|
||||
return ErrSchedulerFallbackLimited
|
||||
}
|
||||
return ErrSchedulerCacheNotReady
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) withFallbackTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if s.cfg == nil || s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds <= 0 {
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
timeout := time.Duration(s.cfg.Gateway.Scheduling.DbFallbackTimeoutSeconds) * time.Second
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
if remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) isRunModeSimple() bool {
|
||||
return s.cfg != nil && s.cfg.RunMode == config.RunModeSimple
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) outboxPollInterval() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return time.Second
|
||||
}
|
||||
sec := s.cfg.Gateway.Scheduling.OutboxPollIntervalSeconds
|
||||
if sec <= 0 {
|
||||
return time.Second
|
||||
}
|
||||
return time.Duration(sec) * time.Second
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
sec := s.cfg.Gateway.Scheduling.FullRebuildIntervalSeconds
|
||||
if sec <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(sec) * time.Second
|
||||
}
|
||||
|
||||
func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) {
|
||||
buckets := make([]SchedulerBucket, 0)
|
||||
platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity}
|
||||
for _, platform := range platforms {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced})
|
||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeMixed})
|
||||
}
|
||||
}
|
||||
|
||||
if s.isRunModeSimple() || s.groupRepo == nil {
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
for _, group := range groups {
|
||||
if group.Platform == "" {
|
||||
continue
|
||||
}
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeSingle})
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeForced})
|
||||
if group.Platform == PlatformAnthropic || group.Platform == PlatformGemini {
|
||||
buckets = append(buckets, SchedulerBucket{GroupID: group.ID, Platform: group.Platform, Mode: SchedulerModeMixed})
|
||||
}
|
||||
}
|
||||
return dedupeBuckets(buckets), nil
|
||||
}
|
||||
|
||||
func dedupeBuckets(in []SchedulerBucket) []SchedulerBucket {
|
||||
seen := make(map[string]struct{}, len(in))
|
||||
out := make([]SchedulerBucket, 0, len(in))
|
||||
for _, bucket := range in {
|
||||
key := bucket.String()
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, bucket)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func derefAccounts(accounts []*Account) []Account {
|
||||
if len(accounts) == 0 {
|
||||
return []Account{}
|
||||
}
|
||||
out := make([]Account, 0, len(accounts))
|
||||
for _, account := range accounts {
|
||||
if account == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, *account)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseInt64Slice(value any) []int64 {
|
||||
raw, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]int64, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
if v, ok := toInt64(item); ok && v > 0 {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toInt64(value any) (int64, bool) {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case json.Number:
|
||||
parsed, err := strconv.ParseInt(v.String(), 10, 64)
|
||||
return parsed, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
type fallbackLimiter struct {
|
||||
maxQPS int
|
||||
mu sync.Mutex
|
||||
window time.Time
|
||||
count int
|
||||
}
|
||||
|
||||
func newFallbackLimiter(maxQPS int) *fallbackLimiter {
|
||||
if maxQPS <= 0 {
|
||||
return nil
|
||||
}
|
||||
return &fallbackLimiter{
|
||||
maxQPS: maxQPS,
|
||||
window: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *fallbackLimiter) Allow() bool {
|
||||
if l == nil || l.maxQPS <= 0 {
|
||||
return true
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if now.Sub(l.window) >= time.Second {
|
||||
l.window = now
|
||||
l.count = 0
|
||||
}
|
||||
if l.count >= l.maxQPS {
|
||||
return false
|
||||
}
|
||||
l.count++
|
||||
return true
|
||||
}
|
||||
63
backend/internal/service/session_limit_cache.go
Normal file
63
backend/internal/service/session_limit_cache.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionLimitCache 管理账号级别的活跃会话跟踪
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制
|
||||
//
|
||||
// Key 格式: session_limit:account:{accountID}
|
||||
// 数据结构: Sorted Set (member=sessionUUID, score=timestamp)
|
||||
//
|
||||
// 会话在空闲超时后自动过期,无需手动清理
|
||||
type SessionLimitCache interface {
|
||||
// RegisterSession 注册会话活动
|
||||
// - 如果会话已存在,刷新其时间戳并返回 true
|
||||
// - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true
|
||||
// - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝)
|
||||
//
|
||||
// 参数:
|
||||
// accountID: 账号 ID
|
||||
// sessionUUID: 从 metadata.user_id 中提取的会话 UUID
|
||||
// maxSessions: 最大并发会话数限制
|
||||
// idleTimeout: 会话空闲超时时间
|
||||
//
|
||||
// 返回:
|
||||
// allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
// error: 操作错误
|
||||
RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
|
||||
|
||||
// RefreshSession 刷新现有会话的时间戳
|
||||
// 用于活跃会话保持活动状态
|
||||
RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
|
||||
|
||||
// GetActiveSessionCount 获取当前活跃会话数
|
||||
// 返回未过期的会话数量
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
// ========== 5h窗口费用缓存 ==========
|
||||
// Key 格式: window_cost:account:{accountID}
|
||||
// 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力
|
||||
|
||||
// GetWindowCost 获取缓存的窗口费用
|
||||
// 返回 (cost, true, nil) 如果缓存命中
|
||||
// 返回 (0, false, nil) 如果缓存未命中
|
||||
// 返回 (0, false, err) 如果发生错误
|
||||
GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
|
||||
|
||||
// SetWindowCost 设置窗口费用缓存
|
||||
SetWindowCost(ctx context.Context, accountID int64, cost float64) error
|
||||
|
||||
// GetWindowCostBatch 批量获取窗口费用缓存
|
||||
// 返回 map[accountID]cost,缓存未命中的账号不在 map 中
|
||||
GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user