merge: sync upstream changes
This commit is contained in:
@@ -9,16 +9,19 @@ import (
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
|
||||
// BillingRateMultiplier 返回账号计费倍率。
|
||||
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
|
||||
// - 允许 0,表示该账号计费为 0
|
||||
// - 负数属于非法数据,出于安全考虑按 1.0 处理
|
||||
func (a *Account) BillingRateMultiplier() float64 {
|
||||
if a == nil || a.RateMultiplier == nil {
|
||||
return 1.0
|
||||
}
|
||||
if *a.RateMultiplier < 0 {
|
||||
return 1.0
|
||||
}
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WindowCostSchedulability 窗口费用调度状态
|
||||
type WindowCostSchedulability int
|
||||
|
||||
const (
|
||||
// WindowCostSchedulable 可正常调度
|
||||
WindowCostSchedulable WindowCostSchedulability = iota
|
||||
// WindowCostStickyOnly 仅允许粘性会话
|
||||
WindowCostStickyOnly
|
||||
// WindowCostNotSchedulable 完全不可调度
|
||||
WindowCostNotSchedulable
|
||||
)
|
||||
|
||||
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
|
||||
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
|
||||
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_limit"]; ok {
|
||||
return parseExtraFloat64(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
|
||||
// 默认值为 10
|
||||
func (a *Account) GetWindowCostStickyReserve() float64 {
|
||||
if a.Extra == nil {
|
||||
return 10.0
|
||||
}
|
||||
if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
|
||||
val := parseExtraFloat64(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 10.0
|
||||
}
|
||||
|
||||
// GetMaxSessions 获取最大并发会话数
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetMaxSessions() int {
|
||||
if a.Extra == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := a.Extra["max_sessions"]; ok {
|
||||
return parseExtraInt(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
|
||||
// 默认值为 5 分钟
|
||||
func (a *Account) GetSessionIdleTimeoutMinutes() int {
|
||||
if a.Extra == nil {
|
||||
return 5
|
||||
}
|
||||
if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
|
||||
val := parseExtraInt(v)
|
||||
if val > 0 {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 5
|
||||
}
|
||||
|
||||
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
|
||||
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
|
||||
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
|
||||
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
|
||||
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
|
||||
limit := a.GetWindowCostLimit()
|
||||
if limit <= 0 {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
if currentWindowCost < limit {
|
||||
return WindowCostSchedulable
|
||||
}
|
||||
|
||||
stickyReserve := a.GetWindowCostStickyReserve()
|
||||
if currentWindowCost < limit+stickyReserve {
|
||||
return WindowCostStickyOnly
|
||||
}
|
||||
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return v
|
||||
case float32:
|
||||
return float64(v)
|
||||
case int:
|
||||
return float64(v)
|
||||
case int64:
|
||||
return float64(v)
|
||||
case json.Number:
|
||||
if f, err := v.Float64(); err == nil {
|
||||
return f
|
||||
}
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseExtraInt 从 extra 字段解析 int 值
|
||||
func parseExtraInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
|
||||
var a Account
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
|
||||
require.Nil(t, a.RateMultiplier)
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
|
||||
v := 0.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 0.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
|
||||
v := -1.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
@@ -50,11 +50,13 @@ type AccountRepository interface {
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||
ClearModelRateLimits(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
@@ -63,14 +65,15 @@ type AccountRepository interface {
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
|
||||
@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
|
||||
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearModelRateLimits call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
panic("unexpected UpdateSessionWindow call")
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
|
||||
}
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
//
|
||||
// cost: 账号口径费用(total_cost * account_rate_multiplier)
|
||||
// standard_cost: 标准费用(total_cost,不含倍率)
|
||||
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
|
||||
type WindowStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
StandardCost float64 `json:"standard_cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
}
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
windowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}
|
||||
|
||||
// 缓存窗口统计(1 分钟)
|
||||
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
||||
}
|
||||
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
|
||||
// 用于账号列表页面显示当前窗口费用
|
||||
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,8 @@ type AdminService interface {
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
||||
DeleteProxy(ctx context.Context, id int64) error
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
|
||||
BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||
|
||||
@@ -105,6 +106,9 @@ type CreateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -124,6 +128,9 @@ type UpdateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -136,6 +143,7 @@ type CreateAccountInput struct {
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
GroupIDs []int64
|
||||
ExpiresAt *int64
|
||||
AutoPauseOnExpired *bool
|
||||
@@ -151,8 +159,9 @@ type UpdateAccountInput struct {
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
ExpiresAt *int64
|
||||
@@ -162,16 +171,17 @@ type UpdateAccountInput struct {
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
type BulkUpdateAccountsInput struct {
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
AccountIDs []int64
|
||||
Name string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
|
||||
Status string
|
||||
Schedulable *bool
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
@@ -220,23 +230,35 @@ type GenerateRedeemCodesInput struct {
|
||||
ValidityDays int // 订阅类型专用:有效天数
|
||||
}
|
||||
|
||||
// ProxyTestResult represents the result of testing a proxy
|
||||
type ProxyTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
type ProxyBatchDeleteResult struct {
|
||||
DeletedIDs []int64 `json:"deleted_ids"`
|
||||
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||
type ProxyBatchDeleteSkipped struct {
|
||||
ID int64 `json:"id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// ProxyTestResult represents the result of testing a proxy
|
||||
type ProxyTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
}
|
||||
|
||||
// ProxyExitInfo represents proxy exit information from ip-api.com
|
||||
type ProxyExitInfo struct {
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
IP string
|
||||
City string
|
||||
Region string
|
||||
Country string
|
||||
CountryCode string
|
||||
}
|
||||
|
||||
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||
@@ -254,6 +276,7 @@ type adminServiceImpl struct {
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
@@ -267,6 +290,7 @@ func NewAdminService(
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
@@ -278,6 +302,7 @@ func NewAdminService(
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
}
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
group.ModelRouting = input.ModelRouting
|
||||
}
|
||||
if input.ModelRoutingEnabled != nil {
|
||||
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Priority != nil {
|
||||
account.Priority = *input.Priority
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
account.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.Status != "" {
|
||||
account.Status = input.Status
|
||||
}
|
||||
@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
}
|
||||
}
|
||||
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier < 0 {
|
||||
return nil, errors.New("rate_multiplier must be >= 0")
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
if input.Priority != nil {
|
||||
repoUpdates.Priority = input.Priority
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
repoUpdates.RateMultiplier = input.RateMultiplier
|
||||
}
|
||||
if input.Status != "" {
|
||||
repoUpdates.Status = &input.Status
|
||||
}
|
||||
@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
s.attachProxyLatency(ctx, proxies)
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
return s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.attachProxyLatency(ctx, proxies)
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
|
||||
@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Probe latency asynchronously so creation isn't blocked by network timeout.
|
||||
go s.probeProxyLatency(context.Background(), proxy)
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
||||
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return ErrProxyInUse
|
||||
}
|
||||
return s.proxyRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
|
||||
// Return mock data for now - would need a dedicated repository method
|
||||
return []Account{}, 0, nil
|
||||
func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
|
||||
result := &ProxyBatchDeleteResult{}
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
|
||||
if err != nil {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if count > 0 {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: ErrProxyInUse.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if err := s.proxyRepo.Delete(ctx, id); err != nil {
|
||||
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
|
||||
ID: id,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
result.DeletedIDs = append(result.DeletedIDs, id)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
@@ -1240,23 +1344,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
proxyURL := proxy.URL()
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||
if err != nil {
|
||||
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return &ProxyTestResult{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
latency := latencyMs
|
||||
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
|
||||
Success: true,
|
||||
LatencyMs: &latency,
|
||||
Message: "Proxy is accessible",
|
||||
IPAddress: exitInfo.IP,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
Region: exitInfo.Region,
|
||||
City: exitInfo.City,
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return &ProxyTestResult{
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
Success: true,
|
||||
Message: "Proxy is accessible",
|
||||
LatencyMs: latencyMs,
|
||||
IPAddress: exitInfo.IP,
|
||||
City: exitInfo.City,
|
||||
Region: exitInfo.Region,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
|
||||
if s.proxyProber == nil || proxy == nil {
|
||||
return
|
||||
}
|
||||
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
|
||||
if err != nil {
|
||||
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||
Success: false,
|
||||
Message: err.Error(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
latency := latencyMs
|
||||
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
|
||||
Success: true,
|
||||
LatencyMs: &latency,
|
||||
Message: "Proxy is accessible",
|
||||
IPAddress: exitInfo.IP,
|
||||
Country: exitInfo.Country,
|
||||
CountryCode: exitInfo.CountryCode,
|
||||
Region: exitInfo.Region,
|
||||
City: exitInfo.City,
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
||||
// 如果存在混合,返回错误提示用户确认
|
||||
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
|
||||
if s.proxyLatencyCache == nil || len(proxies) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
ids = append(ids, proxies[i].ID)
|
||||
}
|
||||
|
||||
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
|
||||
if err != nil {
|
||||
log.Printf("Warning: load proxy latency cache failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := range proxies {
|
||||
info := latencies[proxies[i].ID]
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
if info.Success {
|
||||
proxies[i].LatencyStatus = "success"
|
||||
proxies[i].LatencyMs = info.LatencyMs
|
||||
} else {
|
||||
proxies[i].LatencyStatus = "failed"
|
||||
}
|
||||
proxies[i].LatencyMessage = info.Message
|
||||
proxies[i].IPAddress = info.IPAddress
|
||||
proxies[i].Country = info.Country
|
||||
proxies[i].CountryCode = info.CountryCode
|
||||
proxies[i].Region = info.Region
|
||||
proxies[i].City = info.City
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
|
||||
if s.proxyLatencyCache == nil || info == nil {
|
||||
return
|
||||
}
|
||||
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
|
||||
log.Printf("Warning: store proxy latency cache failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
||||
func getAccountPlatform(accountPlatform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
|
||||
type accountRepoStubForBulkUpdate struct {
|
||||
accountRepoStub
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
|
||||
@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
deletedIDs []int64
|
||||
deleteErr error
|
||||
countErr error
|
||||
accountCount int64
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
||||
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
panic("unexpected CountAccountsByProxyID call")
|
||||
if s.countErr != nil {
|
||||
return 0, s.countErr
|
||||
}
|
||||
return s.accountCount, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
panic("unexpected ListAccountSummariesByProxyID call")
|
||||
}
|
||||
|
||||
type redeemRepoStub struct {
|
||||
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||
repo := &proxyRepoStub{accountCount: 2}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 77)
|
||||
require.ErrorIs(t, err, ErrProxyInUse)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||
|
||||
@@ -564,6 +564,10 @@ urlFallbackLoop:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -574,6 +578,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -615,6 +620,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
@@ -645,6 +651,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
@@ -697,6 +704,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "signature_error",
|
||||
@@ -740,6 +748,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "signature_retry_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
@@ -770,6 +779,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
Kind: kind,
|
||||
@@ -817,6 +827,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
@@ -1371,6 +1382,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -1412,6 +1424,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
@@ -1442,6 +1455,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
@@ -1543,6 +1557,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
@@ -1559,6 +1574,7 @@ urlFallbackLoop:
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "http_error",
|
||||
@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: upstreamStatus,
|
||||
UpstreamRequestID: upstreamRequestID,
|
||||
Kind: "http_error",
|
||||
|
||||
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
if !a.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
if a.isModelRateLimited(requestedModel) {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := antigravityTokenCacheKey(account)
|
||||
cacheKey := AntigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func antigravityTokenCacheKey(account *Account) string {
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
|
||||
@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
|
||||
@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
Concurrency: 3,
|
||||
},
|
||||
Group: &APIKeyAuthGroupSnapshot{
|
||||
ID: groupID,
|
||||
Name: "g",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
ID: groupID,
|
||||
Name: "g",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-opus-*": {1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
require.Equal(t, int64(1), apiKey.ID)
|
||||
require.Equal(t, int64(2), apiKey.User.ID)
|
||||
require.Equal(t, groupID, apiKey.Group.ID)
|
||||
require.True(t, apiKey.Group.ModelRoutingEnabled)
|
||||
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
|
||||
208
backend/internal/service/claude_token_provider.go
Normal file
208
backend/internal/service/claude_token_provider.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeTokenRefreshSkew = 3 * time.Minute
|
||||
claudeTokenCacheSkew = 5 * time.Minute
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache ClaudeTokenCache,
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
// 构建新 credentials,保留原有字段
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
939
backend/internal/service/claude_token_provider_test.go
Normal file
939
backend/internal/service/claude_token_provider_test.go
Normal file
@@ -0,0 +1,939 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// claudeTokenCacheStub implements ClaudeTokenCache for testing
|
||||
type claudeTokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
|
||||
return &claudeTokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
|
||||
type claudeAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// claudeOAuthServiceStub implements OAuthService methods for testing
|
||||
type claudeOAuthServiceStub struct {
|
||||
tokenInfo *TokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
// testClaudeTokenProvider is a test version that uses the stub OAuth service
|
||||
type testClaudeTokenProvider struct {
|
||||
accountRepo *claudeAccountRepoStub
|
||||
tokenCache *claudeTokenCacheStub
|
||||
oauthService *claudeOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// Build new credentials
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > claudeTokenCacheSkew {
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
getErr: errors.New("db connection failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 113,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still work, just using the passed-in account
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
updateErr: errors.New("db write failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still return token even if update fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-access-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
"custom_field": "should-be-preserved",
|
||||
"organization": "test-org",
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new-access-token", token)
|
||||
|
||||
// Verify existing fields are preserved
|
||||
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
|
||||
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
|
||||
// Verify new fields are updated
|
||||
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
|
||||
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 116,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// After lock is acquired, cache should have the token (simulating another worker)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other-worker"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set)
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
// After lock wait, return token from credentials
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cacheKey := "claude:account:303"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 303,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 304,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 305,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 306,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
|
||||
@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
sessionHash := "sticky"
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{sessionHash: 1},
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-a": {1},
|
||||
"claude-b": {2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil, // legacy path
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
|
||||
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
|
||||
})
|
||||
|
||||
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -40,6 +41,21 @@ const (
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
)
|
||||
|
||||
func (s *GatewayService) debugModelRoutingEnabled() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func shortSessionHash(sessionHash string) string {
|
||||
if sessionHash == "" {
|
||||
return ""
|
||||
}
|
||||
if len(sessionHash) <= 8 {
|
||||
return sessionHash
|
||||
}
|
||||
return sessionHash[:8]
|
||||
}
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
@@ -196,6 +212,8 @@ type GatewayService struct {
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
claudeTokenProvider *ClaudeTokenProvider
|
||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -215,6 +233,8 @@ func NewGatewayService(
|
||||
identityService *IdentityService,
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
claudeTokenProvider *ClaudeTokenProvider,
|
||||
sessionLimitCache SessionLimitCache,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -232,6 +252,8 @@ func NewGatewayService(
|
||||
identityService: identityService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
claudeTokenProvider: claudeTokenProvider,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
// 提取会话 UUID(用于会话数量限制)
|
||||
sessionUUID := extractSessionUUID(metadataUserID)
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
groupPlatform = group.Platform
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
|
||||
derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, err
|
||||
}
|
||||
preferOAuth := platform == PlatformGemini
|
||||
if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
@@ -873,28 +911,242 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
|
||||
// 获取模型路由配置(仅 anthropic 平台)
|
||||
var routingAccountIDs []int64
|
||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||
routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
|
||||
if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
|
||||
keys := make([]string, 0, len(group.ModelRouting))
|
||||
for k := range group.ModelRouting {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
const maxKeys = 20
|
||||
if len(keys) > maxKeys {
|
||||
keys = keys[:maxKeys]
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
|
||||
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
|
||||
// 1. 过滤出路由列表中可调度的账号
|
||||
var routingCandidates []*Account
|
||||
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
|
||||
for _, routingAccountID := range routingAccountIDs {
|
||||
if isExcluded(routingAccountID) {
|
||||
filteredExcluded++
|
||||
continue
|
||||
}
|
||||
account, ok := accountByID[routingAccountID]
|
||||
if !ok || !account.IsSchedulable() {
|
||||
if !ok {
|
||||
filteredMissing++
|
||||
} else {
|
||||
filteredUnsched++
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
|
||||
filteredPlatform++
|
||||
continue
|
||||
}
|
||||
if !account.IsSchedulableForModel(requestedModel) {
|
||||
filteredModelScope++
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||
filteredModelMapping++
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
|
||||
filteredWindowCost++
|
||||
continue
|
||||
}
|
||||
routingCandidates = append(routingCandidates, account)
|
||||
}
|
||||
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
|
||||
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
|
||||
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
|
||||
}
|
||||
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
if stickyAccount.IsSchedulable() &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
stickyAccount.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 批量获取负载信息
|
||||
routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
|
||||
for _, acc := range routingCandidates {
|
||||
routingLoads = append(routingLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
|
||||
|
||||
// 3. 按负载感知排序
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var routingAvailable []accountWithLoad
|
||||
for _, acc := range routingCandidates {
|
||||
loadInfo := routingLoadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
|
||||
}
|
||||
}
|
||||
|
||||
if len(routingAvailable) > 0 {
|
||||
// 排序:优先级 > 负载率 > 最后使用时间
|
||||
sort.SliceStable(routingAvailable, func(i, j int) bool {
|
||||
a, b := routingAvailable[i], routingAvailable[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
// 4. 尝试获取槽位
|
||||
for _, item := range routingAvailable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
|
||||
acc := routingAvailable[0].account
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
|
||||
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
// 粘性命中仅在当前可调度候选集中生效。
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
// 窗口费用检查(非粘性会话路径)
|
||||
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||
return nil
|
||||
}
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil || group == nil {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Preserve existing behavior: model routing only applies to anthropic groups.
|
||||
if group.Platform != PlatformAnthropic {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ids := group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return nil, nil, nil
|
||||
@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示可调度,false 表示不可调度
|
||||
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
limit := account.GetWindowCostLimit()
|
||||
if limit <= 0 {
|
||||
return true // 未启用窗口费用限制
|
||||
}
|
||||
|
||||
// 尝试从缓存获取窗口费用
|
||||
var currentCost float64
|
||||
if s.sessionLimitCache != nil {
|
||||
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
|
||||
currentCost = cost
|
||||
goto checkSchedulability
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
{
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
// 失败开放:查询失败时允许调度
|
||||
return true
|
||||
}
|
||||
|
||||
// 使用标准费用(不含账号倍率)
|
||||
currentCost = stats.StandardCost
|
||||
|
||||
// 设置缓存(忽略错误)
|
||||
if s.sessionLimitCache != nil {
|
||||
_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
|
||||
}
|
||||
}
|
||||
|
||||
checkSchedulability:
|
||||
schedulability := account.CheckWindowCostSchedulability(currentCost)
|
||||
|
||||
switch schedulability {
|
||||
case WindowCostSchedulable:
|
||||
return true
|
||||
case WindowCostStickyOnly:
|
||||
return isSticky
|
||||
case WindowCostNotSchedulable:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
maxSessions := account.GetMaxSessions()
|
||||
if maxSessions <= 0 || sessionUUID == "" {
|
||||
return true // 未启用会话限制或无会话ID
|
||||
}
|
||||
|
||||
if s.sessionLimitCache == nil {
|
||||
return true // 缓存不可用时允许通过
|
||||
}
|
||||
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
|
||||
if err != nil {
|
||||
// 失败开放:缓存错误时允许通过
|
||||
return true
|
||||
}
|
||||
return allowed
|
||||
}
|
||||
|
||||
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
|
||||
// 格式: user_{64位hex}_account__session_{uuid}
|
||||
func extractSessionUUID(metadataUserID string) string {
|
||||
if metadataUserID == "" {
|
||||
return ""
|
||||
}
|
||||
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
// ============ Model Routing (legacy path): apply before sticky session ============
|
||||
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
|
||||
// so switching model can switch upstream account within the same sticky session.
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
for _, id := range routingAccountIDs {
|
||||
if id > 0 {
|
||||
routingSet[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, ok := routingSet[acc.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
@@ -1292,13 +1795,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
if !accountsLoaded {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
// ============ Model Routing (legacy path): apply before sticky session ============
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
for _, id := range routingAccountIDs {
|
||||
if id > 0 {
|
||||
routingSet[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, ok := routingSet[acc.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
@@ -1385,9 +2000,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
if !accountsLoaded {
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
}
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
|
||||
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
|
||||
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "signature_error",
|
||||
@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
Kind: "signature_retry_thinking",
|
||||
@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "signature_retry_tools_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
|
||||
@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry",
|
||||
@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "retry_exhausted_failover",
|
||||
@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover_on_400",
|
||||
@@ -3283,30 +3920,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
|
||||
@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
// In this code path `body` is already the JSON sent to upstream.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "signature_error",
|
||||
@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "retry",
|
||||
@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "failover",
|
||||
@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
// In this code path `body` is already the JSON sent to upstream.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: upstreamReqID,
|
||||
Kind: "retry",
|
||||
@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "failover",
|
||||
@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: requestID,
|
||||
Kind: "http_error",
|
||||
@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: upstreamStatus,
|
||||
UpstreamRequestID: upstreamRequestID,
|
||||
Kind: "http_error",
|
||||
|
||||
@@ -125,6 +125,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -138,6 +141,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
|
||||
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
|
||||
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
||||
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
||||
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
||||
DeleteAccessToken(ctx context.Context, cacheKey string) error
|
||||
|
||||
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
||||
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
}
|
||||
|
||||
cacheKey := geminiTokenCacheKey(account)
|
||||
cacheKey := GeminiTokenCacheKey(account)
|
||||
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func geminiTokenCacheKey(account *Account) string {
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return projectID
|
||||
return "gemini:" + projectID
|
||||
}
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
return "gemini:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64
|
||||
@@ -27,6 +30,12 @@ type Group struct {
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
// value: 优先账号 ID 列表
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
|
||||
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
|
||||
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
|
||||
if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 精确匹配优先
|
||||
if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
|
||||
// 2. 通配符匹配(前缀匹配)
|
||||
for pattern, accountIDs := range g.ModelRouting {
|
||||
if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchModelPattern 检查模型是否匹配模式
|
||||
// 支持 * 通配符,如 "claude-opus-*" 匹配 "claude-opus-4-20250514"
|
||||
func matchModelPattern(pattern, model string) bool {
|
||||
if pattern == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// 处理 * 通配符(仅支持末尾通配符)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
return strings.HasPrefix(model, prefix)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
56
backend/internal/service/model_rate_limit.go
Normal file
56
backend/internal/service/model_rate_limit.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const modelRateLimitsKey = "model_rate_limits"
|
||||
const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
|
||||
|
||||
func resolveModelRateLimitScope(requestedModel string) (string, bool) {
|
||||
model := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
if strings.Contains(model, "sonnet") {
|
||||
return modelRateLimitScopeClaudeSonnet, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (a *Account) isModelRateLimited(requestedModel string) bool {
|
||||
scope, ok := resolveModelRateLimitScope(requestedModel)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
resetAt := a.modelRateLimitResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*resetAt)
|
||||
}
|
||||
|
||||
func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawLimit, ok := rawLimits[scope].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
openAITokenProvider *OpenAITokenProvider
|
||||
toolCorrector *CodexToolCorrector
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream HTTPUpstream,
|
||||
deferredService *DeferredService,
|
||||
openAITokenProvider *OpenAITokenProvider,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
deferredService: deferredService,
|
||||
openAITokenProvider: openAITokenProvider,
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case AccountTypeOAuth:
|
||||
// 使用 TokenProvider 获取缓存的 token
|
||||
if s.openAITokenProvider != nil {
|
||||
accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
// 降级:TokenProvider 未配置时直接从账号读取
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "failover",
|
||||
@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "http_error",
|
||||
@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: kind,
|
||||
@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
data = correctedData
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
// 写入客户端(客户端断开后继续 drain 上游)
|
||||
if !clientDisconnected {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
// correctToolCallsInResponseBody 修正响应体中的工具调用
|
||||
func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
|
||||
if changed {
|
||||
return []byte(corrected)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
// Correct tool calls in final response
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
} else {
|
||||
usage = s.parseSSEUsageFromBody(bodyText)
|
||||
if originalModel != mappedModel {
|
||||
@@ -1470,28 +1515,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 添加 UserAgent
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
|
||||
func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
|
||||
// 创建一个简单的 service 实例来测试工具修正
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
changed bool
|
||||
}{
|
||||
{
|
||||
name: "correct apply_patch in response body",
|
||||
input: []byte(`{
|
||||
"choices": [{
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"function": {"name": "apply_patch"}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in response body",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "update_plan"}
|
||||
}]
|
||||
}`),
|
||||
expected: "todowrite",
|
||||
changed: true,
|
||||
},
|
||||
{
|
||||
name: "no change for correct tool name",
|
||||
input: []byte(`{
|
||||
"tool_calls": [{
|
||||
"function": {"name": "edit"}
|
||||
}]
|
||||
}`),
|
||||
expected: "edit",
|
||||
changed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := service.correctToolCallsInResponseBody(tt.input)
|
||||
resultStr := string(result)
|
||||
|
||||
// 检查是否包含期望的工具名称
|
||||
if !strings.Contains(resultStr, tt.expected) {
|
||||
t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
|
||||
}
|
||||
|
||||
// 对于预期有变化的情况,验证结果与输入不同
|
||||
if tt.changed && string(result) == string(tt.input) {
|
||||
t.Error("expected result to be different from input, but they are the same")
|
||||
}
|
||||
|
||||
// 对于预期无变化的情况,验证结果与输入相同
|
||||
if !tt.changed && string(result) != string(tt.input) {
|
||||
t.Error("expected result to be same as input, but they are different")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
|
||||
func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
if service.toolCorrector == nil {
|
||||
t.Fatal("toolCorrector should not be nil")
|
||||
}
|
||||
|
||||
// 测试修正器可以正常工作
|
||||
data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
|
||||
corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
|
||||
if !changed {
|
||||
t.Error("expected tool call to be corrected")
|
||||
}
|
||||
|
||||
if !strings.Contains(corrected, "edit") {
|
||||
t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolCorrectionStats 测试工具修正统计功能
|
||||
func TestToolCorrectionStats(t *testing.T) {
|
||||
service := &OpenAIGatewayService{
|
||||
toolCorrector: NewCodexToolCorrector(),
|
||||
}
|
||||
|
||||
// 执行几次修正
|
||||
testData := []string{
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
|
||||
`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
|
||||
}
|
||||
|
||||
for _, data := range testData {
|
||||
service.toolCorrector.CorrectToolCallsInSSEData(data)
|
||||
}
|
||||
|
||||
stats := service.toolCorrector.GetStats()
|
||||
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
}
|
||||
189
backend/internal/service/openai_token_provider.go
Normal file
189
backend/internal/service/openai_token_provider.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAITokenRefreshSkew = 3 * time.Minute
|
||||
openAITokenCacheSkew = 5 * time.Minute
|
||||
openAILockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
// OpenAITokenProvider 管理 OpenAI OAuth 账户的 access_token
|
||||
type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
func NewOpenAITokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache OpenAITokenCache,
|
||||
openAIOAuthService *OpenAIOAuthService,
|
||||
) *OpenAITokenProvider {
|
||||
return &OpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
openAIOAuthService: openAIOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("openai_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
if p.accountRepo != nil {
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
|
||||
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(openAILockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
810
backend/internal/service/openai_token_provider_test.go
Normal file
810
backend/internal/service/openai_token_provider_test.go
Normal file
@@ -0,0 +1,810 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// openAITokenCacheStub implements OpenAITokenCache for testing
|
||||
type openAITokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newOpenAITokenCacheStub() *openAITokenCacheStub {
|
||||
return &openAITokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// openAIAccountRepoStub is a minimal stub implementing only the methods used by OpenAITokenProvider
|
||||
type openAIAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing
|
||||
type openAIOAuthServiceStub struct {
|
||||
tokenInfo *OpenAITokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
|
||||
now := time.Now()
|
||||
return map[string]any{
|
||||
"access_token": info.AccessToken,
|
||||
"refresh_token": info.RefreshToken,
|
||||
"expires_at": now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// We need to directly test with the stub - create a custom provider
|
||||
customProvider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := customProvider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
// testOpenAITokenProvider is a test version that uses the stub OAuth service
|
||||
type testOpenAITokenProvider struct {
|
||||
accountRepo *openAIAccountRepoStub
|
||||
tokenCache *openAITokenCacheStub
|
||||
oauthService *openAIOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an openai oauth account")
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond) // Short wait for test
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > openAITokenCacheSkew {
|
||||
ttl = until - openAITokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get the token set by the "winner" or the original
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an openai oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
accountRepo := &openAIAccountRepoStub{}
|
||||
oauthService := &openAIOAuthServiceStub{
|
||||
tokenInfo: &OpenAITokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
|
||||
// Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
|
||||
originalGet := int32(0)
|
||||
cache.tokens[cacheKey] = "" // Empty initially
|
||||
|
||||
provider := &testOpenAITokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// In a goroutine, set the cached token after a small delay (simulating race)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the refreshed token or the cached one
|
||||
require.NotEmpty(t, token)
|
||||
_ = originalGet // Suppress unused warning
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// Should get either the fallback token or the refreshed one
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 201,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set) - should use credentials
|
||||
account := &Account{
|
||||
ID: 202,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
// Without OAuth service, refresh will fail but token should be returned from credentials
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cacheKey := "openai:account:203"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 203,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token) // Should fall back to credentials
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 204,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 205,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 206,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
307
backend/internal/service/openai_tool_corrector.go
Normal file
307
backend/internal/service/openai_tool_corrector.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
|
||||
var codexToolNameMapping = map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"applyPatch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"updatePlan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"readPlan": "todoread",
|
||||
"search_files": "grep",
|
||||
"searchFiles": "grep",
|
||||
"list_files": "glob",
|
||||
"listFiles": "glob",
|
||||
"read_file": "read",
|
||||
"readFile": "read",
|
||||
"write_file": "write",
|
||||
"writeFile": "write",
|
||||
"execute_bash": "bash",
|
||||
"executeBash": "bash",
|
||||
"exec_bash": "bash",
|
||||
"execBash": "bash",
|
||||
}
|
||||
|
||||
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
|
||||
type ToolCorrectionStats struct {
|
||||
TotalCorrected int `json:"total_corrected"`
|
||||
CorrectionsByTool map[string]int `json:"corrections_by_tool"`
|
||||
}
|
||||
|
||||
// CodexToolCorrector 处理 Codex 工具调用的自动修正
|
||||
type CodexToolCorrector struct {
|
||||
stats ToolCorrectionStats
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCodexToolCorrector 创建新的工具修正器
|
||||
func NewCodexToolCorrector() *CodexToolCorrector {
|
||||
return &CodexToolCorrector{
|
||||
stats: ToolCorrectionStats{
|
||||
CorrectionsByTool: make(map[string]int),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CorrectToolCallsInSSEData 修正 SSE 数据中的工具调用
|
||||
// 返回修正后的数据和是否进行了修正
|
||||
func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
|
||||
if data == "" || data == "\n" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 尝试解析 JSON
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &payload); err != nil {
|
||||
// 不是有效的 JSON,直接返回原数据
|
||||
return data, false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 处理 tool_calls 数组
|
||||
if toolCalls, ok := payload["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 function_call 对象
|
||||
if functionCall, ok := payload["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 delta.tool_calls
|
||||
if delta, ok := payload["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
|
||||
if choices, ok := payload["choices"].([]any); ok {
|
||||
for _, choice := range choices {
|
||||
if choiceMap, ok := choice.(map[string]any); ok {
|
||||
// 处理 message 中的工具调用
|
||||
if message, ok := choiceMap["message"].(map[string]any); ok {
|
||||
if toolCalls, ok := message["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := message["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 处理 delta 中的工具调用
|
||||
if delta, ok := choiceMap["delta"].(map[string]any); ok {
|
||||
if toolCalls, ok := delta["tool_calls"].([]any); ok {
|
||||
if c.correctToolCallsArray(toolCalls) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
if functionCall, ok := delta["function_call"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !corrected {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// 序列化回 JSON
|
||||
correctedBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
|
||||
return data, false
|
||||
}
|
||||
|
||||
return string(correctedBytes), true
|
||||
}
|
||||
|
||||
// correctToolCallsArray 修正工具调用数组中的工具名称
|
||||
func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
|
||||
corrected := false
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCallMap, ok := toolCall.(map[string]any); ok {
|
||||
if function, ok := toolCallMap["function"].(map[string]any); ok {
|
||||
if c.correctFunctionCall(function) {
|
||||
corrected = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctFunctionCall 修正单个函数调用的工具名称和参数
|
||||
func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
|
||||
name, ok := functionCall["name"].(string)
|
||||
if !ok || name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 查找并修正工具名称
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
functionCall["name"] = correctName
|
||||
c.recordCorrection(name, correctName)
|
||||
corrected = true
|
||||
name = correctName // 使用修正后的名称进行参数修正
|
||||
}
|
||||
|
||||
// 修正工具参数(基于工具名称)
|
||||
if c.correctToolParameters(name, functionCall) {
|
||||
corrected = true
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// correctToolParameters 修正工具参数以符合 OpenCode 规范
|
||||
func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
|
||||
arguments, ok := functionCall["arguments"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// arguments 可能是字符串(JSON)或已解析的 map
|
||||
var argsMap map[string]any
|
||||
switch v := arguments.(type) {
|
||||
case string:
|
||||
// 解析 JSON 字符串
|
||||
if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
|
||||
return false
|
||||
}
|
||||
case map[string]any:
|
||||
argsMap = v
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
corrected := false
|
||||
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// 移除 workdir 参数(OpenCode 不支持)
|
||||
if _, exists := argsMap["workdir"]; exists {
|
||||
delete(argsMap, "workdir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||
}
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||
}
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||
// 这里可以添加参数名称的映射逻辑
|
||||
if _, exists := argsMap["file_path"]; !exists {
|
||||
if path, exists := argsMap["path"]; exists {
|
||||
argsMap["file_path"] = path
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果修正了参数,需要重新序列化
|
||||
if corrected {
|
||||
if _, wasString := arguments.(string); wasString {
|
||||
// 原本是字符串,序列化回字符串
|
||||
if newArgsJSON, err := json.Marshal(argsMap); err == nil {
|
||||
functionCall["arguments"] = string(newArgsJSON)
|
||||
}
|
||||
} else {
|
||||
// 原本是 map,直接赋值
|
||||
functionCall["arguments"] = argsMap
|
||||
}
|
||||
}
|
||||
|
||||
return corrected
|
||||
}
|
||||
|
||||
// recordCorrection 记录一次工具名称修正
|
||||
func (c *CodexToolCorrector) recordCorrection(from, to string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected++
|
||||
key := fmt.Sprintf("%s->%s", from, to)
|
||||
c.stats.CorrectionsByTool[key]++
|
||||
|
||||
log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
|
||||
from, to, c.stats.TotalCorrected)
|
||||
}
|
||||
|
||||
// GetStats 获取工具修正统计信息
|
||||
func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// 返回副本以避免并发问题
|
||||
statsCopy := ToolCorrectionStats{
|
||||
TotalCorrected: c.stats.TotalCorrected,
|
||||
CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
|
||||
}
|
||||
for k, v := range c.stats.CorrectionsByTool {
|
||||
statsCopy.CorrectionsByTool[k] = v
|
||||
}
|
||||
|
||||
return statsCopy
|
||||
}
|
||||
|
||||
// ResetStats 重置统计信息
|
||||
func (c *CodexToolCorrector) ResetStats() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.stats.TotalCorrected = 0
|
||||
c.stats.CorrectionsByTool = make(map[string]int)
|
||||
}
|
||||
|
||||
// CorrectToolName 直接修正工具名称(用于非 SSE 场景)
|
||||
func CorrectToolName(name string) (string, bool) {
|
||||
if correctName, found := codexToolNameMapping[name]; found {
|
||||
return correctName, true
|
||||
}
|
||||
return name, false
|
||||
}
|
||||
|
||||
// GetToolNameMapping 获取工具名称映射表
|
||||
func GetToolNameMapping() map[string]string {
|
||||
// 返回副本以避免外部修改
|
||||
mapping := make(map[string]string, len(codexToolNameMapping))
|
||||
for k, v := range codexToolNameMapping {
|
||||
mapping[k] = v
|
||||
}
|
||||
return mapping
|
||||
}
|
||||
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
503
backend/internal/service/openai_tool_corrector_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCorrectToolCallsInSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectCorrected bool
|
||||
checkFunc func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "newline only",
|
||||
input: "\n",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
input: "not a json",
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct apply_patch in tool_calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in result")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct update_plan in function_call",
|
||||
input: `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
functionCall, ok := payload["function_call"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function_call format")
|
||||
}
|
||||
if functionCall["name"] != "todowrite" {
|
||||
t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct search_files in delta.tool_calls",
|
||||
input: `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
delta, ok := payload["delta"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid delta format")
|
||||
}
|
||||
toolCalls, ok := delta["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in delta")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "grep" {
|
||||
t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct list_files in choices.message.tool_calls",
|
||||
input: `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
t.Fatal("No choices found in result")
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid choice format")
|
||||
}
|
||||
message, ok := choice["message"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid message format")
|
||||
}
|
||||
toolCalls, ok := message["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in message")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "glob" {
|
||||
t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no correction needed",
|
||||
input: `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
|
||||
expectCorrected: false,
|
||||
},
|
||||
{
|
||||
name: "correct multiple tool calls",
|
||||
input: `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) < 2 {
|
||||
t.Fatal("Expected at least 2 tool_calls")
|
||||
}
|
||||
|
||||
toolCall1, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid first tool_call format")
|
||||
}
|
||||
func1, ok := toolCall1["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid first function format")
|
||||
}
|
||||
if func1["name"] != "edit" {
|
||||
t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
|
||||
}
|
||||
|
||||
toolCall2, ok := toolCalls[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid second tool_call format")
|
||||
}
|
||||
func2, ok := toolCall2["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid second function format")
|
||||
}
|
||||
if func2["name"] != "read" {
|
||||
t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "camelCase format - applyPatch",
|
||||
input: `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
|
||||
expectCorrected: true,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
toolCalls, ok := payload["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in result")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
functionCall, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
if functionCall["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
|
||||
if corrected != tt.expectCorrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
|
||||
}
|
||||
|
||||
if !corrected && result != tt.input {
|
||||
t.Errorf("Expected unchanged result when not corrected")
|
||||
}
|
||||
|
||||
if tt.checkFunc != nil {
|
||||
tt.checkFunc(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectToolName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
corrected bool
|
||||
}{
|
||||
{"apply_patch", "edit", true},
|
||||
{"applyPatch", "edit", true},
|
||||
{"update_plan", "todowrite", true},
|
||||
{"updatePlan", "todowrite", true},
|
||||
{"read_plan", "todoread", true},
|
||||
{"readPlan", "todoread", true},
|
||||
{"search_files", "grep", true},
|
||||
{"searchFiles", "grep", true},
|
||||
{"list_files", "glob", true},
|
||||
{"listFiles", "glob", true},
|
||||
{"read_file", "read", true},
|
||||
{"readFile", "read", true},
|
||||
{"write_file", "write", true},
|
||||
{"writeFile", "write", true},
|
||||
{"execute_bash", "bash", true},
|
||||
{"executeBash", "bash", true},
|
||||
{"exec_bash", "bash", true},
|
||||
{"execBash", "bash", true},
|
||||
{"unknown_tool", "unknown_tool", false},
|
||||
{"read", "read", false},
|
||||
{"edit", "edit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result, corrected := CorrectToolName(tt.input)
|
||||
|
||||
if corrected != tt.corrected {
|
||||
t.Errorf("Expected corrected=%v, got %v", tt.corrected, corrected)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetToolNameMapping(t *testing.T) {
|
||||
mapping := GetToolNameMapping()
|
||||
|
||||
expectedMappings := map[string]string{
|
||||
"apply_patch": "edit",
|
||||
"update_plan": "todowrite",
|
||||
"read_plan": "todoread",
|
||||
"search_files": "grep",
|
||||
"list_files": "glob",
|
||||
}
|
||||
|
||||
for from, to := range expectedMappings {
|
||||
if mapping[from] != to {
|
||||
t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
|
||||
}
|
||||
}
|
||||
|
||||
mapping["test_tool"] = "test_value"
|
||||
newMapping := GetToolNameMapping()
|
||||
if _, exists := newMapping["test_tool"]; exists {
|
||||
t.Error("Modifications to returned mapping should not affect original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrectorStats(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
stats := corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)
|
||||
corrector.CorrectToolCallsInSSEData(`{"tool_calls":[{"function":{"name":"update_plan"}}]}`)
|
||||
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 3 {
|
||||
t.Errorf("Expected TotalCorrected=3, got %d", stats.TotalCorrected)
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
|
||||
t.Errorf("Expected apply_patch->edit count=2, got %d", stats.CorrectionsByTool["apply_patch->edit"])
|
||||
}
|
||||
|
||||
if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
|
||||
t.Errorf("Expected update_plan->todowrite count=1, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
|
||||
}
|
||||
|
||||
corrector.ResetStats()
|
||||
stats = corrector.GetStats()
|
||||
if stats.TotalCorrected != 0 {
|
||||
t.Errorf("Expected TotalCorrected=0 after reset, got %d", stats.TotalCorrected)
|
||||
}
|
||||
if len(stats.CorrectionsByTool) != 0 {
|
||||
t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(stats.CorrectionsByTool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexSSEData(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
input := `{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-5.1-codex",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"file\":\"test.go\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": null
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, corrected := corrector.CorrectToolCallsInSSEData(input)
|
||||
|
||||
if !corrected {
|
||||
t.Error("Expected data to be corrected")
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &payload); err != nil {
|
||||
t.Fatalf("Failed to parse result: %v", err)
|
||||
}
|
||||
|
||||
choices, ok := payload["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
t.Fatal("No choices found in result")
|
||||
}
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid choice format")
|
||||
}
|
||||
delta, ok := choice["delta"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid delta format")
|
||||
}
|
||||
toolCalls, ok := delta["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("No tool_calls found in delta")
|
||||
}
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid tool_call format")
|
||||
}
|
||||
function, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Invalid function format")
|
||||
}
|
||||
|
||||
if function["name"] != "edit" {
|
||||
t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCorrectToolParameters 测试工具参数修正
|
||||
func TestCorrectToolParameters(t *testing.T) {
|
||||
corrector := NewCodexToolCorrector()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||
}{
|
||||
{
|
||||
name: "remove workdir from bash tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"command": true,
|
||||
"workdir": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rename path to file_path in edit tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "apply_patch",
|
||||
"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"file_path": true,
|
||||
"path": false,
|
||||
"old_string": true,
|
||||
"new_string": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
|
||||
if !changed {
|
||||
t.Error("expected data to be corrected")
|
||||
}
|
||||
|
||||
// 解析修正后的数据
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(corrected), &result); err != nil {
|
||||
t.Fatalf("failed to parse corrected data: %v", err)
|
||||
}
|
||||
|
||||
// 检查工具调用
|
||||
toolCalls, ok := result["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) == 0 {
|
||||
t.Fatal("no tool_calls found in corrected data")
|
||||
}
|
||||
|
||||
toolCall, ok := toolCalls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("invalid tool_call structure")
|
||||
}
|
||||
|
||||
function, ok := toolCall["function"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("no function found in tool_call")
|
||||
}
|
||||
|
||||
argumentsStr, ok := function["arguments"].(string)
|
||||
if !ok {
|
||||
t.Fatal("arguments is not a string")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
|
||||
t.Fatalf("failed to parse arguments: %v", err)
|
||||
}
|
||||
|
||||
// 验证期望的参数
|
||||
for param, shouldExist := range tt.expected {
|
||||
_, exists := args[param]
|
||||
if shouldExist && !exists {
|
||||
t.Errorf("expected parameter %q to exist, but it doesn't", param)
|
||||
}
|
||||
if !shouldExist && exists {
|
||||
t.Errorf("expected parameter %q to not exist, but it does", param)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -235,11 +236,13 @@ func (s *OpsAggregationService) aggregateHourly() {
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggHourlyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -331,11 +334,13 @@ func (s *OpsAggregationService) aggregateDaily() {
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggDailyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -190,6 +190,13 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
return
|
||||
}
|
||||
|
||||
rulesTotal := len(rules)
|
||||
rulesEnabled := 0
|
||||
rulesEvaluated := 0
|
||||
eventsCreated := 0
|
||||
eventsResolved := 0
|
||||
emailsSent := 0
|
||||
|
||||
now := time.Now().UTC()
|
||||
safeEnd := now.Truncate(time.Minute)
|
||||
if safeEnd.IsZero() {
|
||||
@@ -205,8 +212,9 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
if rule == nil || !rule.Enabled || rule.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
rulesEnabled++
|
||||
|
||||
scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
|
||||
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
|
||||
|
||||
windowMinutes := rule.WindowMinutes
|
||||
if windowMinutes <= 0 {
|
||||
@@ -220,6 +228,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
s.resetRuleState(rule.ID, now)
|
||||
continue
|
||||
}
|
||||
rulesEvaluated++
|
||||
|
||||
breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
|
||||
required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
|
||||
@@ -236,6 +245,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Scoped silencing: if a matching silence exists, skip creating a firing event.
|
||||
if s.opsService != nil {
|
||||
platform := strings.TrimSpace(scopePlatform)
|
||||
region := scopeRegion
|
||||
if platform != "" {
|
||||
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||
@@ -267,8 +287,11 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
continue
|
||||
}
|
||||
|
||||
eventsCreated++
|
||||
if created != nil && created.ID > 0 {
|
||||
s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
|
||||
if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
|
||||
emailsSent++
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -278,11 +301,14 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
resolvedAt := now
|
||||
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
|
||||
} else {
|
||||
eventsResolved++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
|
||||
result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
|
||||
@@ -359,9 +385,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
|
||||
return required
|
||||
}
|
||||
|
||||
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
|
||||
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
|
||||
if filters == nil {
|
||||
return "", nil
|
||||
return "", nil, nil
|
||||
}
|
||||
if v, ok := filters["platform"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
@@ -392,7 +418,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
|
||||
}
|
||||
}
|
||||
}
|
||||
return platform, groupID
|
||||
if v, ok := filters["region"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
vv := strings.TrimSpace(s)
|
||||
if vv != "" {
|
||||
region = &vv
|
||||
}
|
||||
}
|
||||
}
|
||||
return platform, groupID, region
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
||||
@@ -504,16 +538,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
||||
return 0, false
|
||||
}
|
||||
return overview.UpstreamErrorRate * 100, true
|
||||
case "p95_latency_ms":
|
||||
if overview.Duration.P95 == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(*overview.Duration.P95), true
|
||||
case "p99_latency_ms":
|
||||
if overview.Duration.P99 == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(*overview.Duration.P99), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
@@ -576,32 +600,32 @@ func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes i
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
|
||||
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
|
||||
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if event.EmailSent {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if !rule.NotifyEmail {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
||||
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if len(emailCfg.Alert.Recipients) == 0 {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
|
||||
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,6 +654,7 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
|
||||
if anySent {
|
||||
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
|
||||
}
|
||||
return anySent
|
||||
}
|
||||
|
||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||
@@ -797,7 +822,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
||||
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
@@ -805,11 +830,17 @@ func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, durat
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg := strings.TrimSpace(result)
|
||||
if msg == "" {
|
||||
msg = "ok"
|
||||
}
|
||||
msg = truncateString(msg, 2048)
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAlertEvaluatorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &msg,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -8,8 +8,9 @@ import "time"
|
||||
// with the existing ops dashboard frontend (backup style).
|
||||
|
||||
const (
|
||||
OpsAlertStatusFiring = "firing"
|
||||
OpsAlertStatusResolved = "resolved"
|
||||
OpsAlertStatusFiring = "firing"
|
||||
OpsAlertStatusResolved = "resolved"
|
||||
OpsAlertStatusManualResolved = "manual_resolved"
|
||||
)
|
||||
|
||||
type OpsAlertRule struct {
|
||||
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertSilence struct {
|
||||
ID int64 `json:"id"`
|
||||
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Region *string `json:"region,omitempty"`
|
||||
|
||||
Until time.Time `json:"until"`
|
||||
Reason string `json:"reason"`
|
||||
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEventFilter struct {
|
||||
Limit int
|
||||
|
||||
// Cursor pagination (descending by fired_at, then id).
|
||||
BeforeFiredAt *time.Time
|
||||
BeforeID *int64
|
||||
|
||||
// Optional filters.
|
||||
Status string
|
||||
Severity string
|
||||
Status string
|
||||
Severity string
|
||||
EmailSent *bool
|
||||
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
|
||||
return s.opsRepo.ListAlertEvents(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ev == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
|
||||
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
|
||||
}
|
||||
if input.RuleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(input.Platform) == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
|
||||
}
|
||||
if input.Until.IsZero() {
|
||||
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(platform) == "" {
|
||||
return false, nil
|
||||
}
|
||||
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
|
||||
if eventID <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
status = strings.TrimSpace(status)
|
||||
if status == "" {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
||||
|
||||
@@ -149,7 +149,7 @@ func (s *OpsCleanupService) runScheduled() {
|
||||
log.Printf("[OpsCleanup] cleanup failed: %v", err)
|
||||
return
|
||||
}
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
|
||||
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
|
||||
}
|
||||
|
||||
@@ -330,12 +330,13 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
result := truncateString(counts.String(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
@@ -343,6 +344,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &result,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
|
||||
}
|
||||
|
||||
// computeBusinessHealth calculates business health score (0-100)
|
||||
// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
|
||||
// Components: Error Rate (50%) + TTFT (50%)
|
||||
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
|
||||
// SLA score: 99.5% → 100, 95% → 0 (linear)
|
||||
slaScore := 100.0
|
||||
slaPct := clampFloat64(overview.SLA*100, 0, 100)
|
||||
if slaPct < 99.5 {
|
||||
if slaPct >= 95 {
|
||||
slaScore = (slaPct - 95) / 4.5 * 100
|
||||
} else {
|
||||
slaScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Error rate score: 0.5% → 100, 5% → 0 (linear)
|
||||
// Error rate score: 1% → 100, 10% → 0 (linear)
|
||||
// Combines request errors and upstream errors
|
||||
errorScore := 100.0
|
||||
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
|
||||
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
|
||||
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
|
||||
if combinedErrorPct > 0.5 {
|
||||
if combinedErrorPct <= 5 {
|
||||
errorScore = (5 - combinedErrorPct) / 4.5 * 100
|
||||
if combinedErrorPct > 1.0 {
|
||||
if combinedErrorPct <= 10.0 {
|
||||
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
|
||||
} else {
|
||||
errorScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Latency score: 1s → 100, 10s → 0 (linear)
|
||||
// Uses P99 of duration (TTFT is less critical for overall health)
|
||||
latencyScore := 100.0
|
||||
if overview.Duration.P99 != nil {
|
||||
p99 := float64(*overview.Duration.P99)
|
||||
// TTFT score: 1s → 100, 3s → 0 (linear)
|
||||
// Time to first token is critical for user experience
|
||||
ttftScore := 100.0
|
||||
if overview.TTFT.P99 != nil {
|
||||
p99 := float64(*overview.TTFT.P99)
|
||||
if p99 > 1000 {
|
||||
if p99 <= 10000 {
|
||||
latencyScore = (10000 - p99) / 9000 * 100
|
||||
if p99 <= 3000 {
|
||||
ttftScore = (3000 - p99) / 2000 * 100
|
||||
} else {
|
||||
latencyScore = 0
|
||||
ttftScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Weighted combination
|
||||
return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
|
||||
// Weighted combination: 50% error rate + 50% TTFT
|
||||
return errorScore*0.5 + ttftScore*0.5
|
||||
}
|
||||
|
||||
// computeInfraHealth calculates infrastructure health score (0-100)
|
||||
|
||||
@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
||||
MemoryUsagePercent: float64Ptr(75),
|
||||
},
|
||||
},
|
||||
wantMin: 60,
|
||||
wantMax: 85,
|
||||
wantMin: 96,
|
||||
wantMax: 97,
|
||||
},
|
||||
{
|
||||
name: "DB failure",
|
||||
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
||||
MemoryUsagePercent: float64Ptr(30),
|
||||
},
|
||||
},
|
||||
wantMin: 25,
|
||||
wantMax: 50,
|
||||
wantMin: 84,
|
||||
wantMax: 85,
|
||||
},
|
||||
{
|
||||
name: "combined failures - business healthy + infra degraded",
|
||||
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 50,
|
||||
wantMax: 60,
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "error rate boundary 0.5%",
|
||||
name: "error rate boundary 1%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0.005,
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0.01,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 95,
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "latency boundary 1000ms",
|
||||
name: "error rate 5%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
SLA: 0.95,
|
||||
ErrorRate: 0.05,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 77,
|
||||
wantMax: 78,
|
||||
},
|
||||
{
|
||||
name: "TTFT boundary 2s",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(1000)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(2000)},
|
||||
},
|
||||
wantMin: 95,
|
||||
wantMax: 100,
|
||||
wantMin: 75,
|
||||
wantMax: 75,
|
||||
},
|
||||
{
|
||||
name: "upstream error dominates",
|
||||
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
|
||||
UpstreamErrorRate: 0.03,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 75,
|
||||
wantMin: 88,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -6,24 +6,43 @@ type OpsErrorLog struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Phase string `json:"phase"`
|
||||
Type string `json:"type"`
|
||||
// Standardized classification
|
||||
// - phase: request|auth|routing|upstream|network|internal
|
||||
// - owner: client|provider|platform
|
||||
// - source: client_request|upstream_http|gateway
|
||||
Phase string `json:"phase"`
|
||||
Type string `json:"type"`
|
||||
|
||||
Owner string `json:"error_owner"`
|
||||
Source string `json:"error_source"`
|
||||
|
||||
Severity string `json:"severity"`
|
||||
|
||||
StatusCode int `json:"status_code"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
|
||||
LatencyMs *int `json:"latency_ms"`
|
||||
IsRetryable bool `json:"is_retryable"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
|
||||
Resolved bool `json:"resolved"`
|
||||
ResolvedAt *time.Time `json:"resolved_at"`
|
||||
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
|
||||
ResolvedByUserName string `json:"resolved_by_user_name"`
|
||||
ResolvedRetryID *int64 `json:"resolved_retry_id"`
|
||||
ResolvedStatusRaw string `json:"-"`
|
||||
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
|
||||
UserID *int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
|
||||
ClientIP *string `json:"client_ip"`
|
||||
RequestPath string `json:"request_path"`
|
||||
@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
|
||||
GroupID *int64
|
||||
AccountID *int64
|
||||
|
||||
StatusCodes []int
|
||||
Phase string
|
||||
Query string
|
||||
StatusCodes []int
|
||||
StatusCodesOther bool
|
||||
Phase string
|
||||
Owner string
|
||||
Source string
|
||||
Resolved *bool
|
||||
Query string
|
||||
UserQuery string // Search by user email
|
||||
|
||||
// Optional correlation keys for exact matching.
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
|
||||
// View controls error categorization for list endpoints.
|
||||
// - errors: show actionable errors (exclude business-limited / 429 / 529)
|
||||
// - excluded: only show excluded errors
|
||||
// - all: show everything
|
||||
View string
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
|
||||
SourceErrorID int64 `json:"source_error_id"`
|
||||
Mode string `json:"mode"`
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
PinnedAccountName string `json:"pinned_account_name"`
|
||||
|
||||
Status string `json:"status"`
|
||||
StartedAt *time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at"`
|
||||
DurationMs *int64 `json:"duration_ms"`
|
||||
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool `json:"success"`
|
||||
HTTPStatusCode *int `json:"http_status_code"`
|
||||
UpstreamRequestID *string `json:"upstream_request_id"`
|
||||
UsedAccountID *int64 `json:"used_account_id"`
|
||||
UsedAccountName string `json:"used_account_name"`
|
||||
ResponsePreview *string `json:"response_preview"`
|
||||
ResponseTruncated *bool `json:"response_truncated"`
|
||||
|
||||
// Optional correlation
|
||||
ResultRequestID *string `json:"result_request_id"`
|
||||
ResultErrorID *int64 `json:"result_error_id"`
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ type OpsRepository interface {
|
||||
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
||||
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
||||
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
|
||||
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
|
||||
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
|
||||
|
||||
// Lightweight window stats (for realtime WS / quick sampling).
|
||||
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
|
||||
@@ -39,12 +41,17 @@ type OpsRepository interface {
|
||||
DeleteAlertRule(ctx context.Context, id int64) error
|
||||
|
||||
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
|
||||
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
|
||||
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
|
||||
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
|
||||
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
|
||||
|
||||
// Alert silences
|
||||
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
|
||||
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
|
||||
|
||||
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
|
||||
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
|
||||
// It is set by OpsService.RecordError before persisting.
|
||||
UpstreamErrorsJSON *string
|
||||
|
||||
DurationMs *int
|
||||
TimeToFirstTokenMs *int64
|
||||
|
||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
|
||||
FinishedAt time.Time
|
||||
DurationMs int64
|
||||
|
||||
// Optional correlation
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool
|
||||
HTTPStatusCode *int
|
||||
UpstreamRequestID *string
|
||||
UsedAccountID *int64
|
||||
ResponsePreview *string
|
||||
ResponseTruncated *bool
|
||||
|
||||
// Optional correlation (legacy fields kept)
|
||||
ResultRequestID *string
|
||||
ResultErrorID *int64
|
||||
|
||||
@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
|
||||
LastErrorAt *time.Time
|
||||
LastError *string
|
||||
LastDurationMs *int64
|
||||
|
||||
// LastResult is an optional human-readable summary of the last successful run.
|
||||
LastResult *string
|
||||
}
|
||||
|
||||
type OpsJobHeartbeat struct {
|
||||
@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
|
||||
LastErrorAt *time.Time `json:"last_error_at"`
|
||||
LastError *string `json:"last_error"`
|
||||
LastDurationMs *int64 `json:"last_duration_ms"`
|
||||
LastResult *string `json:"last_result"`
|
||||
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
|
||||
return w.totalWritten > int64(w.limit)
|
||||
}
|
||||
|
||||
const (
|
||||
OpsRetryModeUpstreamEvent = "upstream_event"
|
||||
)
|
||||
|
||||
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
|
||||
}
|
||||
|
||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errorLog == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||
}
|
||||
|
||||
var pinned *int64
|
||||
if mode == OpsRetryModeUpstream {
|
||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||
pinned = pinnedAccountID
|
||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||
pinned = errorLog.AccountID
|
||||
} else {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
||||
}
|
||||
}
|
||||
|
||||
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
|
||||
}
|
||||
|
||||
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
|
||||
// idx is 0-based. It always pins the original event account_id.
|
||||
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if idx < 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
|
||||
}
|
||||
|
||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if errorLog == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
|
||||
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
|
||||
if err != nil {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
|
||||
}
|
||||
if idx >= len(events) {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
|
||||
}
|
||||
ev := events[idx]
|
||||
if ev == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
|
||||
}
|
||||
if ev.AccountID <= 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||
}
|
||||
|
||||
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
|
||||
if upstreamBody == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
|
||||
}
|
||||
|
||||
override := *errorLog
|
||||
override.RequestBody = upstreamBody
|
||||
pinned := ev.AccountID
|
||||
|
||||
// Persist as upstream_event, execute as upstream pinned retry.
|
||||
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
|
||||
}
|
||||
|
||||
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
|
||||
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
|
||||
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
||||
}
|
||||
}
|
||||
|
||||
errorLog, err := s.GetErrorLogByID(ctx, errorID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||
if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
|
||||
}
|
||||
|
||||
var pinned *int64
|
||||
if mode == OpsRetryModeUpstream {
|
||||
if execMode == OpsRetryModeUpstream {
|
||||
if pinnedAccountID != nil && *pinnedAccountID > 0 {
|
||||
pinned = pinnedAccountID
|
||||
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
|
||||
pinned = errorLog.AccountID
|
||||
} else {
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
|
||||
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
||||
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
|
||||
defer cancel()
|
||||
|
||||
execRes := s.executeRetry(execCtx, errorLog, mode, pinned)
|
||||
execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
|
||||
|
||||
finishedAt := time.Now()
|
||||
result.FinishedAt = finishedAt
|
||||
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
|
||||
msg := result.ErrorMessage
|
||||
updateErrMsg = &msg
|
||||
}
|
||||
// Keep legacy result_request_id empty; use upstream_request_id instead.
|
||||
var resultRequestID *string
|
||||
if strings.TrimSpace(result.UpstreamRequestID) != "" {
|
||||
v := result.UpstreamRequestID
|
||||
resultRequestID = &v
|
||||
}
|
||||
|
||||
finalStatus := result.Status
|
||||
if strings.TrimSpace(finalStatus) == "" {
|
||||
finalStatus = opsRetryStatusFailed
|
||||
}
|
||||
|
||||
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
|
||||
httpStatus := result.HTTPStatusCode
|
||||
upstreamReqID := result.UpstreamRequestID
|
||||
usedAccountID := result.UsedAccountID
|
||||
preview := result.ResponsePreview
|
||||
truncated := result.ResponseTruncated
|
||||
|
||||
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
|
||||
ID: attemptID,
|
||||
Status: finalStatus,
|
||||
FinishedAt: finishedAt,
|
||||
DurationMs: result.DurationMs,
|
||||
ResultRequestID: resultRequestID,
|
||||
ErrorMessage: updateErrMsg,
|
||||
ID: attemptID,
|
||||
Status: finalStatus,
|
||||
FinishedAt: finishedAt,
|
||||
DurationMs: result.DurationMs,
|
||||
Success: &success,
|
||||
HTTPStatusCode: &httpStatus,
|
||||
UpstreamRequestID: &upstreamReqID,
|
||||
UsedAccountID: usedAccountID,
|
||||
ResponsePreview: &preview,
|
||||
ResponseTruncated: &truncated,
|
||||
ResultRequestID: resultRequestID,
|
||||
ErrorMessage: updateErrMsg,
|
||||
}); err != nil {
|
||||
// Best-effort: retry itself already executed; do not fail the API response.
|
||||
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
|
||||
} else if success {
|
||||
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
|
||||
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
|
||||
if s.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gateway service not available")
|
||||
}
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
|
||||
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
|
||||
}
|
||||
|
||||
@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() {
|
||||
return
|
||||
}
|
||||
|
||||
reportsTotal := len(reports)
|
||||
reportsDue := 0
|
||||
sentAttempts := 0
|
||||
|
||||
for _, report := range reports {
|
||||
if report == nil || !report.Enabled {
|
||||
continue
|
||||
@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() {
|
||||
if report.NextRunAt.After(now) {
|
||||
continue
|
||||
}
|
||||
reportsDue++
|
||||
|
||||
if err := s.runReport(ctx, report, now); err != nil {
|
||||
attempts, err := s.runReport(ctx, report, now)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
return
|
||||
}
|
||||
sentAttempts += attempts
|
||||
}
|
||||
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
|
||||
result := truncateString(fmt.Sprintf("reports=%d due=%d send_attempts=%d", reportsTotal, reportsDue, sentAttempts), 2048)
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
|
||||
}
|
||||
|
||||
type opsScheduledReport struct {
|
||||
@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) error {
|
||||
func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) (int, error) {
|
||||
if s == nil || s.opsService == nil || s.emailService == nil || report == nil {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
||||
|
||||
content, err := s.generateReportHTML(ctx, report, now)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
// Skip sending when the report decides not to emit content (e.g., digest below min count).
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
recipients := report.Recipients
|
||||
@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
|
||||
}
|
||||
}
|
||||
if len(recipients) == 0 {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
|
||||
|
||||
attempts := 0
|
||||
for _, to := range recipients {
|
||||
addr := strings.TrimSpace(to)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
attempts++
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
|
||||
@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType
|
||||
_ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err()
|
||||
}
|
||||
|
||||
func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||
func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
|
||||
if s == nil || s.opsService == nil || s.opsService.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
msg := strings.TrimSpace(result)
|
||||
if msg == "" {
|
||||
msg = "ok"
|
||||
}
|
||||
msg = truncateString(msg, 2048)
|
||||
_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsScheduledReportJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
LastResult: &msg,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
|
||||
out.Detail = ""
|
||||
}
|
||||
|
||||
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
|
||||
if out.UpstreamRequestBody != "" {
|
||||
// Reuse the same sanitization/trimming strategy as request body storage.
|
||||
// Keep it small so it is safe to persist in ops_error_logs JSON.
|
||||
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
|
||||
if sanitized != "" {
|
||||
out.UpstreamRequestBody = sanitized
|
||||
if truncated {
|
||||
out.Kind = strings.TrimSpace(out.Kind)
|
||||
if out.Kind == "" {
|
||||
out.Kind = "upstream"
|
||||
}
|
||||
out.Kind = out.Kind + ":request_body_truncated"
|
||||
}
|
||||
} else {
|
||||
out.UpstreamRequestBody = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Drop fully-empty events (can happen if only status code was known).
|
||||
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
|
||||
continue
|
||||
@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
|
||||
if s.opsRepo == nil {
|
||||
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
|
||||
}
|
||||
return s.opsRepo.ListErrorLogs(ctx, filter)
|
||||
result, err := s.opsRepo.ListErrorLogs(ctx, filter)
|
||||
if err != nil {
|
||||
log.Printf("[Ops] GetErrorLogs failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
|
||||
@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
|
||||
return detail, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if errorID <= 0 {
|
||||
return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
|
||||
}
|
||||
items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return []*OpsRetryAttempt{}, nil
|
||||
}
|
||||
return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if errorID <= 0 {
|
||||
return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
|
||||
}
|
||||
// Best-effort ensure the error exists
|
||||
if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
|
||||
}
|
||||
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
|
||||
}
|
||||
return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
|
||||
}
|
||||
|
||||
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
|
||||
bytesLen = len(raw)
|
||||
if len(raw) == 0 {
|
||||
@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: store a minimal placeholder (still valid JSON).
|
||||
placeholder := map[string]any{
|
||||
"request_body_truncated": true,
|
||||
// Last resort: keep JSON shape but drop big fields.
|
||||
// This avoids downstream code that expects certain top-level keys from crashing.
|
||||
if root, ok := decoded.(map[string]any); ok {
|
||||
placeholder := shallowCopyMap(root)
|
||||
placeholder["request_body_truncated"] = true
|
||||
|
||||
// Replace potentially huge arrays/strings, but keep the keys present.
|
||||
for _, k := range []string{"messages", "contents", "input", "prompt"} {
|
||||
if _, exists := placeholder[k]; exists {
|
||||
placeholder[k] = []any{}
|
||||
}
|
||||
}
|
||||
for _, k := range []string{"text"} {
|
||||
if _, exists := placeholder[k]; exists {
|
||||
placeholder[k] = ""
|
||||
}
|
||||
}
|
||||
|
||||
encoded4, err4 := json.Marshal(placeholder)
|
||||
if err4 == nil {
|
||||
if len(encoded4) <= maxBytes {
|
||||
return string(encoded4), true, bytesLen
|
||||
}
|
||||
}
|
||||
}
|
||||
if model := extractString(decoded, "model"); model != "" {
|
||||
placeholder["model"] = model
|
||||
}
|
||||
encoded4, err4 := json.Marshal(placeholder)
|
||||
|
||||
// Final fallback: minimal valid JSON.
|
||||
encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
|
||||
if err4 != nil {
|
||||
return "", true, bytesLen
|
||||
}
|
||||
@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
|
||||
}
|
||||
return raw, false
|
||||
}
|
||||
|
||||
func extractString(v any, key string) string {
|
||||
root, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
s, _ := root[key].(string)
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
|
||||
Aggregation: OpsAggregationSettings{
|
||||
AggregationEnabled: false,
|
||||
},
|
||||
IgnoreCountTokensErrors: false,
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
IgnoreCountTokensErrors: false,
|
||||
IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
|
||||
IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
|
||||
AutoRefreshEnabled: false,
|
||||
AutoRefreshIntervalSec: 30,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
|
||||
|
||||
func defaultOpsMetricThresholds() *OpsMetricThresholds {
|
||||
slaMin := 99.5
|
||||
latencyMax := 2000.0
|
||||
ttftMax := 500.0
|
||||
reqErrMax := 5.0
|
||||
upstreamErrMax := 5.0
|
||||
return &OpsMetricThresholds{
|
||||
SLAPercentMin: &slaMin,
|
||||
LatencyP99MsMax: &latencyMax,
|
||||
TTFTp99MsMax: &ttftMax,
|
||||
RequestErrorRatePercentMax: &reqErrMax,
|
||||
UpstreamErrorRatePercentMax: &upstreamErrMax,
|
||||
@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
|
||||
if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
|
||||
return nil, errors.New("sla_percent_min must be between 0 and 100")
|
||||
}
|
||||
if cfg.LatencyP99MsMax != nil && *cfg.LatencyP99MsMax < 0 {
|
||||
return nil, errors.New("latency_p99_ms_max must be >= 0")
|
||||
}
|
||||
if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
|
||||
return nil, errors.New("ttft_p99_ms_max must be >= 0")
|
||||
}
|
||||
|
||||
@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
|
||||
|
||||
type OpsMetricThresholds struct {
|
||||
SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红
|
||||
LatencyP99MsMax *float64 `json:"latency_p99_ms_max,omitempty"` // 延迟P99高于此值变红
|
||||
TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红
|
||||
RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红
|
||||
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
|
||||
@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
|
||||
|
||||
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
|
||||
type OpsAdvancedSettings struct {
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
DataRetention OpsDataRetentionSettings `json:"data_retention"`
|
||||
Aggregation OpsAggregationSettings `json:"aggregation"`
|
||||
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
|
||||
IgnoreContextCanceled bool `json:"ignore_context_canceled"`
|
||||
IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
|
||||
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
|
||||
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
|
||||
}
|
||||
|
||||
type OpsDataRetentionSettings struct {
|
||||
|
||||
@@ -15,6 +15,11 @@ const (
|
||||
OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
|
||||
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
|
||||
OpsUpstreamErrorsKey = "ops_upstream_errors"
|
||||
|
||||
// Best-effort capture of the current upstream request body so ops can
|
||||
// retry the specific upstream attempt (not just the client request).
|
||||
// This value is sanitized+trimmed before being persisted.
|
||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||
)
|
||||
|
||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||
@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
|
||||
AtUnixMs int64 `json:"at_unix_ms,omitempty"`
|
||||
|
||||
// Context
|
||||
Platform string `json:"platform,omitempty"`
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
Platform string `json:"platform,omitempty"`
|
||||
AccountID int64 `json:"account_id,omitempty"`
|
||||
AccountName string `json:"account_name,omitempty"`
|
||||
|
||||
// Outcome
|
||||
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
|
||||
|
||||
// Best-effort upstream request capture (sanitized+trimmed).
|
||||
// Required for retrying a specific upstream attempt.
|
||||
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
|
||||
|
||||
// Best-effort upstream response capture (sanitized+trimmed).
|
||||
UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
|
||||
|
||||
// Kind: http_error | request_error | retry_exhausted | failover
|
||||
Kind string `json:"kind,omitempty"`
|
||||
|
||||
@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
}
|
||||
ev.Platform = strings.TrimSpace(ev.Platform)
|
||||
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
|
||||
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
|
||||
ev.Kind = strings.TrimSpace(ev.Kind)
|
||||
ev.Message = strings.TrimSpace(ev.Message)
|
||||
ev.Detail = strings.TrimSpace(ev.Detail)
|
||||
@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
|
||||
}
|
||||
|
||||
// If the caller didn't explicitly pass upstream request body but the gateway
|
||||
// stored it on the context, attach it so ops can retry this specific attempt.
|
||||
if ev.UpstreamRequestBody == "" {
|
||||
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var existing []*OpsUpstreamErrorEvent
|
||||
if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
|
||||
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
|
||||
@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
|
||||
s := string(raw)
|
||||
return &s
|
||||
}
|
||||
|
||||
func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return []*OpsUpstreamErrorEvent{}, nil
|
||||
}
|
||||
var out []*OpsUpstreamErrorEvent
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64
|
||||
AccountCount int64
|
||||
LatencyMs *int64
|
||||
LatencyStatus string
|
||||
LatencyMessage string
|
||||
IPAddress string
|
||||
Country string
|
||||
CountryCode string
|
||||
Region string
|
||||
City string
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
ID int64
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Notes *string
|
||||
}
|
||||
|
||||
23
backend/internal/service/proxy_latency_cache.go
Normal file
23
backend/internal/service/proxy_latency_cache.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyLatencyInfo struct {
|
||||
Success bool `json:"success"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type ProxyLatencyCache interface {
|
||||
GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
|
||||
SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
|
||||
ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
@@ -26,6 +27,7 @@ type ProxyRepository interface {
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
|
||||
}
|
||||
|
||||
// CreateProxyRequest 创建代理请求
|
||||
|
||||
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -15,15 +15,16 @@ import (
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
timeoutCounterCache TimeoutCounterCache
|
||||
settingService *SettingService
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
timeoutCounterCache TimeoutCounterCache
|
||||
settingService *SettingService
|
||||
tokenCacheInvalidator TokenCacheInvalidator
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
|
||||
s.settingService = settingService
|
||||
}
|
||||
|
||||
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
|
||||
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
|
||||
slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
tempMatched := false
|
||||
if statusCode != 401 {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
}
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
// 认证失败:停止调度,记录错误
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// 1. 失效缓存
|
||||
if s.tokenCacheInvalidator != nil {
|
||||
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
|
||||
} else {
|
||||
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
|
||||
}
|
||||
}
|
||||
msg := "Authentication failed (401): invalid or expired credentials"
|
||||
if upstreamMsg != "" {
|
||||
msg = "Authentication failed (401): " + upstreamMsg
|
||||
@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers)
|
||||
s.handle429(ctx, account, headers, responseBody)
|
||||
shouldDisable = false
|
||||
case 529:
|
||||
s.handle529(ctx, account)
|
||||
@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
shouldDisable = true
|
||||
} else if statusCode >= 500 {
|
||||
// 未启用自定义错误码时:仅记录5xx错误
|
||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
||||
slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
|
||||
shouldDisable = false
|
||||
}
|
||||
}
|
||||
@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
// NOTE:
|
||||
// - This is a local precheck to reduce upstream 429s.
|
||||
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
if used >= limit {
|
||||
resetAt := start.Add(time.Minute)
|
||||
// Do not persist "rate limited" status from local precheck. See note above.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
|
||||
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
|
||||
slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
|
||||
}
|
||||
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 解析重置时间戳
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
} else {
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 解析Unix时间戳
|
||||
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
} else {
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
|
||||
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
|
||||
return
|
||||
}
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
|
||||
if account == nil || account.Platform != PlatformAnthropic {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(msg, "sonnet")
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Account %d overloaded until %v", account.ID, until)
|
||||
slog.Info("account_overloaded", "account_id", account.ID, "until", until)
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||
@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
end := start.Add(5 * time.Hour)
|
||||
windowStart = &start
|
||||
windowEnd = &end
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||
}
|
||||
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
|
||||
if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.accountRepo.ClearModelRateLimits(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
}
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
||||
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
||||
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
|
||||
slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
||||
|
||||
// 获取系统设置
|
||||
if s.settingService == nil {
|
||||
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
|
||||
slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
|
||||
slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
|
||||
if s.timeoutCounterCache != nil {
|
||||
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
|
||||
if err != nil {
|
||||
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
|
||||
// 继续处理,使用 count=1
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
|
||||
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
|
||||
slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
|
||||
|
||||
// 检查是否达到阈值
|
||||
if count < int64(settings.ThresholdCount) {
|
||||
@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
|
||||
slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
|
||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// 重置超时计数
|
||||
if s.timeoutCounterCache != nil {
|
||||
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
|
||||
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
|
||||
slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
|
||||
slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
|
||||
return true
|
||||
}
|
||||
|
||||
121
backend/internal/service/ratelimit_service_401_test.go
Normal file
121
backend/internal/service/ratelimit_service_401_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type rateLimitAccountRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
setErrorCalls int
|
||||
tempCalls int
|
||||
lastErrorMsg string
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
r.lastErrorMsg = errorMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorRecorder struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
r.accounts = append(r.accounts, account)
|
||||
return r.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
}
|
||||
63
backend/internal/service/session_limit_cache.go
Normal file
63
backend/internal/service/session_limit_cache.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionLimitCache 管理账号级别的活跃会话跟踪
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制
|
||||
//
|
||||
// Key 格式: session_limit:account:{accountID}
|
||||
// 数据结构: Sorted Set (member=sessionUUID, score=timestamp)
|
||||
//
|
||||
// 会话在空闲超时后自动过期,无需手动清理
|
||||
type SessionLimitCache interface {
|
||||
// RegisterSession 注册会话活动
|
||||
// - 如果会话已存在,刷新其时间戳并返回 true
|
||||
// - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true
|
||||
// - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝)
|
||||
//
|
||||
// 参数:
|
||||
// accountID: 账号 ID
|
||||
// sessionUUID: 从 metadata.user_id 中提取的会话 UUID
|
||||
// maxSessions: 最大并发会话数限制
|
||||
// idleTimeout: 会话空闲超时时间
|
||||
//
|
||||
// 返回:
|
||||
// allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
// error: 操作错误
|
||||
RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
|
||||
|
||||
// RefreshSession 刷新现有会话的时间戳
|
||||
// 用于活跃会话保持活动状态
|
||||
RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
|
||||
|
||||
// GetActiveSessionCount 获取当前活跃会话数
|
||||
// 返回未过期的会话数量
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
// ========== 5h窗口费用缓存 ==========
|
||||
// Key 格式: window_cost:account:{accountID}
|
||||
// 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力
|
||||
|
||||
// GetWindowCost 获取缓存的窗口费用
|
||||
// 返回 (cost, true, nil) 如果缓存命中
|
||||
// 返回 (0, false, nil) 如果缓存未命中
|
||||
// 返回 (0, false, err) 如果发生错误
|
||||
GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
|
||||
|
||||
// SetWindowCost 设置窗口费用缓存
|
||||
SetWindowCost(ctx context.Context, accountID int64, cost float64) error
|
||||
|
||||
// GetWindowCostBatch 批量获取窗口费用缓存
|
||||
// 返回 map[accountID]cost,缓存未命中的账号不在 map 中
|
||||
GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
var newTimingWheel = collection.NewTimingWheel
|
||||
|
||||
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
|
||||
type TimingWheelService struct {
|
||||
tw *collection.TimingWheel
|
||||
@@ -15,18 +18,18 @@ type TimingWheelService struct {
|
||||
}
|
||||
|
||||
// NewTimingWheelService creates a new TimingWheelService instance
|
||||
func NewTimingWheelService() *TimingWheelService {
|
||||
func NewTimingWheelService() (*TimingWheelService, error) {
|
||||
// 1 second tick, 3600 slots = supports up to 1 hour delay
|
||||
// execute function: runs func() type tasks
|
||||
tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) {
|
||||
tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) {
|
||||
if fn, ok := value.(func()); ok {
|
||||
fn()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return nil, fmt.Errorf("创建 timing wheel 失败: %w", err)
|
||||
}
|
||||
return &TimingWheelService{tw: tw}
|
||||
return &TimingWheelService{tw: tw}, nil
|
||||
}
|
||||
|
||||
// Start starts the timing wheel
|
||||
|
||||
146
backend/internal/service/timing_wheel_service_test.go
Normal file
146
backend/internal/service/timing_wheel_service_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回 error,但得到 nil")
|
||||
}
|
||||
if svc != nil {
|
||||
t.Fatalf("期望返回 nil svc,但得到非空")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTimingWheelService_Success(t *testing.T) {
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if svc == nil {
|
||||
t.Fatalf("期望 svc 非空,但得到 nil")
|
||||
}
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
var captured collection.Execute
|
||||
newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
captured = execute
|
||||
return original(interval, numSlots, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if captured == nil {
|
||||
t.Fatalf("期望 captured 非空,但得到 nil")
|
||||
}
|
||||
|
||||
called := false
|
||||
captured("k", func() { called = true })
|
||||
if !called {
|
||||
t.Fatalf("期望 execute 回调触发传入函数执行")
|
||||
}
|
||||
|
||||
svc.Stop()
|
||||
}
|
||||
|
||||
func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
ch := make(chan struct{}, 1)
|
||||
svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("等待任务执行超时")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatalf("任务不应重复执行")
|
||||
case <-time.After(80 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
ch := make(chan struct{}, 1)
|
||||
svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
|
||||
svc.Cancel("cancel")
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatalf("任务已取消,不应执行")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
|
||||
return original(10*time.Millisecond, 128, execute)
|
||||
}
|
||||
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
defer svc.Stop()
|
||||
|
||||
var count int32
|
||||
svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
|
||||
|
||||
deadline := time.Now().Add(500 * time.Millisecond)
|
||||
for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if atomic.LoadInt32(&count) < 2 {
|
||||
t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
|
||||
}
|
||||
}
|
||||
41
backend/internal/service/token_cache_invalidator.go
Normal file
41
backend/internal/service/token_cache_invalidator.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
type TokenCacheInvalidator interface {
|
||||
InvalidateToken(ctx context.Context, account *Account) error
|
||||
}
|
||||
|
||||
type CompositeTokenCacheInvalidator struct {
|
||||
cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
|
||||
}
|
||||
|
||||
func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
|
||||
return &CompositeTokenCacheInvalidator{
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
if c == nil || c.cache == nil || account == nil {
|
||||
return nil
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cacheKey string
|
||||
switch account.Platform {
|
||||
case PlatformGemini:
|
||||
cacheKey = GeminiTokenCacheKey(account)
|
||||
case PlatformAntigravity:
|
||||
cacheKey = AntigravityTokenCacheKey(account)
|
||||
case PlatformOpenAI:
|
||||
cacheKey = OpenAITokenCacheKey(account)
|
||||
case PlatformAnthropic:
|
||||
cacheKey = ClaudeTokenCacheKey(account)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return c.cache.DeleteAccessToken(ctx, cacheKey)
|
||||
}
|
||||
268
backend/internal/service/token_cache_invalidator_test.go
Normal file
268
backend/internal/service/token_cache_invalidator_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type geminiTokenCacheStub struct {
|
||||
deletedKeys []string
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
s.deletedKeys = append(s.deletedKeys, cacheKey)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "project-x",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 99,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "ag-project",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 500,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "openai-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 600,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "claude-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
}{
|
||||
{
|
||||
name: "gemini_api_key",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "openai_api_key",
|
||||
account: &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_api_key",
|
||||
account: &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_setup_token",
|
||||
account: &Account{
|
||||
ID: 4,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache.deletedKeys = nil
|
||||
err := invalidator.InvalidateToken(context.Background(), tt.account)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cache.deletedKeys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: "unknown-platform",
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
|
||||
invalidator := NewCompositeTokenCacheInvalidator(nil)
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
|
||||
var invalidator *CompositeTokenCacheInvalidator
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
|
||||
expectedErr := errors.New("redis connection failed")
|
||||
cache := &geminiTokenCacheStub{deleteErr: expectedErr}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
}{
|
||||
{
|
||||
name: "openai_delete_error",
|
||||
account: &Account{
|
||||
ID: 700,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude_delete_error",
|
||||
account: &Account{
|
||||
ID: 800,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := invalidator.InvalidateToken(context.Background(), tt.account)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
|
||||
// 测试所有平台的缓存键生成和删除
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
|
||||
accounts := []*Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
|
||||
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
|
||||
{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
|
||||
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
|
||||
}
|
||||
|
||||
expectedKeys := []string{
|
||||
"gemini:gemini-proj",
|
||||
"ag:ag-proj",
|
||||
"openai:account:3",
|
||||
"claude:account:4",
|
||||
}
|
||||
|
||||
for _, acc := range accounts {
|
||||
err := invalidator.InvalidateToken(context.Background(), acc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, expectedKeys, cache.deletedKeys)
|
||||
}
|
||||
15
backend/internal/service/token_cache_key.go
Normal file
15
backend/internal/service/token_cache_key.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import "strconv"
|
||||
|
||||
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
|
||||
// 格式: "openai:account:{account_id}"
|
||||
func OpenAITokenCacheKey(account *Account) string {
|
||||
return "openai:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
|
||||
// 格式: "claude:account:{account_id}"
|
||||
func ClaudeTokenCacheKey(account *Account) string {
|
||||
return "claude:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
259
backend/internal/service/token_cache_key_test.go
Normal file
259
backend/internal/service/token_cache_key_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiTokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with_project_id",
|
||||
account: &Account{
|
||||
ID: 100,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "my-project-123",
|
||||
},
|
||||
},
|
||||
expected: "gemini:my-project-123",
|
||||
},
|
||||
{
|
||||
name: "project_id_with_whitespace",
|
||||
account: &Account{
|
||||
ID: 101,
|
||||
Credentials: map[string]any{
|
||||
"project_id": " project-with-spaces ",
|
||||
},
|
||||
},
|
||||
expected: "gemini:project-with-spaces",
|
||||
},
|
||||
{
|
||||
name: "empty_project_id_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 102,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "",
|
||||
},
|
||||
},
|
||||
expected: "gemini:account:102",
|
||||
},
|
||||
{
|
||||
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 103,
|
||||
Credentials: map[string]any{
|
||||
"project_id": " ",
|
||||
},
|
||||
},
|
||||
expected: "gemini:account:103",
|
||||
},
|
||||
{
|
||||
name: "no_project_id_key_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 104,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: "gemini:account:104",
|
||||
},
|
||||
{
|
||||
name: "nil_credentials_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 105,
|
||||
Credentials: nil,
|
||||
},
|
||||
expected: "gemini:account:105",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GeminiTokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityTokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with_project_id",
|
||||
account: &Account{
|
||||
ID: 200,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "ag-project-456",
|
||||
},
|
||||
},
|
||||
expected: "ag:ag-project-456",
|
||||
},
|
||||
{
|
||||
name: "project_id_with_whitespace",
|
||||
account: &Account{
|
||||
ID: 201,
|
||||
Credentials: map[string]any{
|
||||
"project_id": " ag-project-spaces ",
|
||||
},
|
||||
},
|
||||
expected: "ag:ag-project-spaces",
|
||||
},
|
||||
{
|
||||
name: "empty_project_id_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 202,
|
||||
Credentials: map[string]any{
|
||||
"project_id": "",
|
||||
},
|
||||
},
|
||||
expected: "ag:account:202",
|
||||
},
|
||||
{
|
||||
name: "whitespace_only_project_id_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 203,
|
||||
Credentials: map[string]any{
|
||||
"project_id": " ",
|
||||
},
|
||||
},
|
||||
expected: "ag:account:203",
|
||||
},
|
||||
{
|
||||
name: "no_project_id_key_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 204,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: "ag:account:204",
|
||||
},
|
||||
{
|
||||
name: "nil_credentials_fallback_to_account_id",
|
||||
account: &Account{
|
||||
ID: 205,
|
||||
Credentials: nil,
|
||||
},
|
||||
expected: "ag:account:205",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AntigravityTokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAITokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_account",
|
||||
account: &Account{
|
||||
ID: 300,
|
||||
},
|
||||
expected: "openai:account:300",
|
||||
},
|
||||
{
|
||||
name: "account_with_credentials",
|
||||
account: &Account{
|
||||
ID: 301,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
},
|
||||
},
|
||||
expected: "openai:account:301",
|
||||
},
|
||||
{
|
||||
name: "account_id_zero",
|
||||
account: &Account{
|
||||
ID: 0,
|
||||
},
|
||||
expected: "openai:account:0",
|
||||
},
|
||||
{
|
||||
name: "large_account_id",
|
||||
account: &Account{
|
||||
ID: 9999999999,
|
||||
},
|
||||
expected: "openai:account:9999999999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OpenAITokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic_account",
|
||||
account: &Account{
|
||||
ID: 400,
|
||||
},
|
||||
expected: "claude:account:400",
|
||||
},
|
||||
{
|
||||
name: "account_with_credentials",
|
||||
account: &Account{
|
||||
ID: 401,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "claude-token",
|
||||
},
|
||||
},
|
||||
expected: "claude:account:401",
|
||||
},
|
||||
{
|
||||
name: "account_id_zero",
|
||||
account: &Account{
|
||||
ID: 0,
|
||||
},
|
||||
expected: "claude:account:0",
|
||||
},
|
||||
{
|
||||
name: "large_account_id",
|
||||
account: &Account{
|
||||
ID: 9999999999,
|
||||
},
|
||||
expected: "claude:account:9999999999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClaudeTokenCacheKey(tt.account)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKeyUniqueness(t *testing.T) {
|
||||
// 确保不同平台的缓存键不会冲突
|
||||
account := &Account{ID: 123}
|
||||
|
||||
openaiKey := OpenAITokenCacheKey(account)
|
||||
claudeKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
|
||||
require.Contains(t, openaiKey, "openai:")
|
||||
require.Contains(t, claudeKey, "claude:")
|
||||
}
|
||||
@@ -14,9 +14,10 @@ import (
|
||||
// TokenRefreshService OAuth token自动刷新服务
|
||||
// 定期检查并刷新即将过期的token
|
||||
type TokenRefreshService struct {
|
||||
accountRepo AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
accountRepo AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
cacheInvalidator TokenCacheInvalidator
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
accountRepo: accountRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
cacheInvalidator: cacheInvalidator,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
361
backend/internal/service/token_refresh_service_test.go
Normal file
361
backend/internal/service/token_refresh_service_test.go
Normal file
@@ -0,0 +1,361 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type tokenRefreshAccountRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updateCalls int
|
||||
setErrorCalls int
|
||||
lastAccount *Account
|
||||
updateErr error
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.lastAccount = account
|
||||
return r.updateErr
|
||||
}
|
||||
|
||||
func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorStub struct {
|
||||
calls int
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
s.calls++
|
||||
return s.err
|
||||
}
|
||||
|
||||
type tokenRefresherStub struct {
|
||||
credentials map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
return r.credentials, nil
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "new-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
require.Equal(t, "new-token", account.GetCredential("access_token"))
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 6,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls)
|
||||
}
|
||||
|
||||
func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
|
||||
func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 8,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "ag-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
|
||||
func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 9,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey, // 非 OAuth
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
|
||||
func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformOpenAI, // OpenAI OAuth 账户
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
|
||||
func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to save credentials")
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
|
||||
func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 2,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("refresh failed"),
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
|
||||
require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
|
||||
func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("network error"), // 可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
|
||||
}
|
||||
|
||||
// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
|
||||
func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
|
||||
repo := &tokenRefreshAccountRepo{}
|
||||
invalidator := &tokenCacheInvalidatorStub{}
|
||||
cfg := &config.Config{
|
||||
TokenRefresh: config.TokenRefreshConfig{
|
||||
MaxRetries: 3,
|
||||
RetryBackoffSeconds: 0,
|
||||
},
|
||||
}
|
||||
service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
refresher := &tokenRefresherStub{
|
||||
err: errors.New("invalid_grant: token revoked"), // 不可重试错误
|
||||
}
|
||||
|
||||
err := service.refreshWithRetry(context.Background(), account, refresher)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, repo.updateCalls)
|
||||
require.Equal(t, 0, invalidator.calls)
|
||||
require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
|
||||
}
|
||||
|
||||
// TestIsNonRetryableRefreshError 测试不可重试错误判断
|
||||
func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{name: "nil_error", err: nil, expected: false},
|
||||
{name: "network_error", err: errors.New("network timeout"), expected: false},
|
||||
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
|
||||
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
||||
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
||||
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
||||
{name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
|
||||
{name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isNonRetryableRefreshError(tt.err)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -33,6 +33,8 @@ type UsageLog struct {
|
||||
TotalCost float64
|
||||
ActualCost float64
|
||||
RateMultiplier float64
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
|
||||
AccountRateMultiplier *float64
|
||||
|
||||
BillingType int8
|
||||
Stream bool
|
||||
|
||||
@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cacheInvalidator TokenCacheInvalidator,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -64,10 +65,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
||||
}
|
||||
|
||||
// ProvideTimingWheelService creates and starts TimingWheelService
|
||||
func ProvideTimingWheelService() *TimingWheelService {
|
||||
svc := NewTimingWheelService()
|
||||
func ProvideTimingWheelService() (*TimingWheelService, error) {
|
||||
svc, err := NewTimingWheelService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
svc.Start()
|
||||
return svc
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// ProvideDeferredService creates and starts DeferredService
|
||||
@@ -108,10 +112,12 @@ func ProvideRateLimitService(
|
||||
tempUnschedCache TempUnschedCache,
|
||||
timeoutCounterCache TimeoutCounterCache,
|
||||
settingService *SettingService,
|
||||
tokenCacheInvalidator TokenCacheInvalidator,
|
||||
) *RateLimitService {
|
||||
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
|
||||
svc.SetTimeoutCounterCache(timeoutCounterCache)
|
||||
svc.SetSettingService(settingService)
|
||||
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -210,10 +216,14 @@ var ProviderSet = wire.NewSet(
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
NewGeminiQuotaService,
|
||||
NewCompositeTokenCacheInvalidator,
|
||||
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
NewAntigravityTokenProvider,
|
||||
NewOpenAITokenProvider,
|
||||
NewClaudeTokenProvider,
|
||||
NewAntigravityGatewayService,
|
||||
ProvideRateLimitService,
|
||||
NewAccountUsageService,
|
||||
|
||||
37
backend/internal/service/wire_test.go
Normal file
37
backend/internal/service/wire_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
)
|
||||
|
||||
func TestProvideTimingWheelService_ReturnsError(t *testing.T) {
|
||||
original := newTimingWheel
|
||||
t.Cleanup(func() { newTimingWheel = original })
|
||||
|
||||
newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
|
||||
svc, err := ProvideTimingWheelService()
|
||||
if err == nil {
|
||||
t.Fatalf("期望返回 error,但得到 nil")
|
||||
}
|
||||
if svc != nil {
|
||||
t.Fatalf("期望返回 nil svc,但得到非空")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvideTimingWheelService_Success(t *testing.T) {
|
||||
svc, err := ProvideTimingWheelService()
|
||||
if err != nil {
|
||||
t.Fatalf("期望 err 为 nil,但得到: %v", err)
|
||||
}
|
||||
if svc == nil {
|
||||
t.Fatalf("期望 svc 非空,但得到 nil")
|
||||
}
|
||||
svc.Stop()
|
||||
}
|
||||
Reference in New Issue
Block a user