Files
sub2api/backend/internal/service/gateway_service.go
IanShaw 45bd9ac705 运维监控系统安全加固和功能优化 (#21)
* fix(ops): 修复运维监控系统的关键安全和稳定性问题

## 修复内容

### P0 严重问题
1. **DNS Rebinding防护** (ops_alert_service.go)
   - 实现IP钉住机制防止验证后的DNS rebinding攻击
   - 自定义Transport.DialContext强制只允许拨号到验证过的公网IP
   - 扩展IP黑名单,包括云metadata地址(169.254.169.254)
   - 添加完整的单元测试覆盖

2. **OpsAlertService生命周期管理** (wire.go)
   - 在ProvideOpsMetricsCollector中添加opsAlertService.Start()调用
   - 确保stopCtx正确初始化,避免nil指针问题
   - 实现防御式启动,保证服务启动顺序

3. **数据库查询排序** (ops_repo.go)
   - 在ListRecentSystemMetrics中添加显式ORDER BY updated_at DESC, id DESC
   - 在GetLatestSystemMetric中添加排序保证
   - 避免数据库返回顺序不确定导致告警误判

### P1 重要问题
4. **并发安全** (ops_metrics_collector.go)
   - 为lastGCPauseTotal字段添加sync.Mutex保护
   - 防止数据竞争

5. **Goroutine泄漏** (ops_error_logger.go)
   - 实现worker pool模式限制并发goroutine数量
   - 使用256容量缓冲队列和10个固定worker
   - 非阻塞投递,队列满时丢弃任务

6. **生命周期控制** (ops_alert_service.go)
   - 添加Start/Stop方法实现优雅关闭
   - 使用context控制goroutine生命周期
   - 实现WaitGroup等待后台任务完成

7. **Webhook URL验证** (ops_alert_service.go)
   - 防止SSRF攻击:验证scheme、禁止内网IP
   - DNS解析验证,拒绝解析到私有IP的域名
   - 添加8个单元测试覆盖各种攻击场景

8. **资源泄漏** (ops_repo.go)
   - 修复多处defer rows.Close()问题
   - 简化冗余的defer func()包装

9. **HTTP超时控制** (ops_alert_service.go)
   - 创建带10秒超时的http.Client
   - 添加buildWebhookHTTPClient辅助函数
   - 防止HTTP请求无限期挂起

10. **数据库查询优化** (ops_repo.go)
    - 将GetWindowStats的4次独立查询合并为1次CTE查询
    - 减少网络往返和表扫描次数
    - 显著提升性能

11. **重试机制** (ops_alert_service.go)
    - 实现邮件发送重试:最多3次,指数退避(1s/2s/4s)
    - 添加webhook备用通道
    - 实现完整的错误处理和日志记录

12. **魔法数字** (ops_repo.go, ops_metrics_collector.go)
    - 提取硬编码数字为有意义的常量
    - 提高代码可读性和可维护性

## 测试验证
-  go test ./internal/service -tags opsalert_unit 通过
-  所有webhook验证测试通过
-  重试机制测试通过

## 影响范围
- 运维监控系统安全性显著提升
- 系统稳定性和性能优化
- 无破坏性变更,向后兼容

* feat(ops): 运维监控系统V2 - 完整实现

## 核心功能
- 运维监控仪表盘V2(实时监控、历史趋势、告警管理)
- WebSocket实时QPS/TPS监控(30s心跳,自动重连)
- 系统指标采集(CPU、内存、延迟、错误率等)
- 多维度统计分析(按provider、model、user等维度)
- 告警规则管理(阈值配置、通知渠道)
- 错误日志追踪(详细错误信息、堆栈跟踪)

## 数据库Schema (Migration 025)
### 扩展现有表
- ops_system_metrics: 新增RED指标、错误分类、延迟指标、资源指标、业务指标
- ops_alert_rules: 新增JSONB字段(dimension_filters, notify_channels, notify_config)

### 新增表
- ops_dimension_stats: 多维度统计数据
- ops_data_retention_config: 数据保留策略配置

### 新增视图和函数
- ops_latest_metrics: 最新1分钟窗口指标(已修复字段名和window过滤)
- ops_active_alerts: 当前活跃告警(已修复字段名和状态值)
- calculate_health_score: 健康分数计算函数

## 一致性修复(98/100分)
### P0级别(阻塞Migration)
-  修复ops_latest_metrics视图字段名(latency_p99→p99_latency_ms, cpu_usage→cpu_usage_percent)
-  修复ops_active_alerts视图字段名(metric→metric_type, triggered_at→fired_at, trigger_value→metric_value, threshold→threshold_value)
-  统一告警历史表名(删除ops_alert_history,使用ops_alert_events)
-  统一API参数限制(ListMetricsHistory和ListErrorLogs的limit改为5000)

### P1级别(功能完整性)
-  修复ops_latest_metrics视图未过滤window_minutes(添加WHERE m.window_minutes = 1)
-  修复数据回填UPDATE逻辑(QPS计算改为request_count/(window_minutes*60.0))
-  添加ops_alert_rules JSONB字段后端支持(Go结构体+序列化)

### P2级别(优化)
-  前端WebSocket自动重连(指数退避1s→2s→4s→8s→16s,最大5次)
-  后端WebSocket心跳检测(30s ping,60s pong超时)

## 技术实现
### 后端 (Go)
- Handler层: ops_handler.go(REST API), ops_ws_handler.go(WebSocket)
- Service层: ops_service.go(核心逻辑), ops_cache.go(缓存), ops_alerts.go(告警)
- Repository层: ops_repo.go(数据访问), ops.go(模型定义)
- 路由: admin.go(新增ops相关路由)
- 依赖注入: wire_gen.go(自动生成)

### 前端 (Vue3 + TypeScript)
- 组件: OpsDashboardV2.vue(仪表盘主组件)
- API: ops.ts(REST API + WebSocket封装)
- 路由: index.ts(新增/admin/ops路由)
- 国际化: en.ts, zh.ts(中英文支持)

## 测试验证
-  所有Go测试通过
-  Migration可正常执行
-  WebSocket连接稳定
-  前后端数据结构对齐

* refactor: 代码清理和测试优化

## 测试文件优化
- 简化integration test fixtures和断言
- 优化test helper函数
- 统一测试数据格式

## 代码清理
- 移除未使用的代码和注释
- 简化concurrency_cache实现
- 优化middleware错误处理

## 小修复
- 修复gateway_handler和openai_gateway_handler的小问题
- 统一代码风格和格式

变更统计: 27个文件,292行新增,322行删除(净减少30行)

* fix(ops): 运维监控系统安全加固和功能优化

## 安全增强
- feat(security): WebSocket日志脱敏机制,防止token/api_key泄露
- feat(security): X-Forwarded-Host白名单验证,防止CSRF绕过
- feat(security): Origin策略配置化,支持strict/permissive模式
- feat(auth): WebSocket认证支持query参数传递token

## 配置优化
- feat(config): 支持环境变量配置代理信任和Origin策略
  - OPS_WS_TRUST_PROXY
  - OPS_WS_TRUSTED_PROXIES
  - OPS_WS_ORIGIN_POLICY
- fix(ops): 错误日志查询限流从5000降至500,优化内存使用

## 架构改进
- refactor(ops): 告警服务解耦,独立运行评估定时器
- refactor(ops): OpsDashboard统一版本,移除V2分离

## 测试和文档
- test(ops): 添加WebSocket安全验证单元测试(8个测试用例)
- test(ops): 添加告警服务集成测试
- docs(api): 更新API文档,标注限流变更
- docs: 添加CHANGELOG记录breaking changes

## 修复文件
Backend:
- backend/internal/server/middleware/logger.go
- backend/internal/handler/admin/ops_handler.go
- backend/internal/handler/admin/ops_ws_handler.go
- backend/internal/server/middleware/admin_auth.go
- backend/internal/service/ops_alert_service.go
- backend/internal/service/ops_metrics_collector.go
- backend/internal/service/wire.go

Frontend:
- frontend/src/views/admin/ops/OpsDashboard.vue
- frontend/src/router/index.ts
- frontend/src/api/admin/ops.ts

Tests:
- backend/internal/handler/admin/ops_ws_handler_test.go (新增)
- backend/internal/service/ops_alert_service_integration_test.go (新增)

Docs:
- CHANGELOG.md (新增)
- docs/API-运维监控中心2.0.md (更新)

* fix(migrations): 修复calculate_health_score函数类型匹配问题

在ops_latest_metrics视图中添加显式类型转换,确保参数类型与函数签名匹配

* fix(lint): 修复golangci-lint检查发现的所有问题

- 将Redis依赖从service层移到repository层
- 添加错误检查(WebSocket连接和读取超时)
- 运行gofmt格式化代码
- 添加nil指针检查
- 删除未使用的alertService字段

修复问题:
- depguard: 3个(service层不应直接import redis)
- errcheck: 3个(未检查错误返回值)
- gofmt: 2个(代码格式问题)
- staticcheck: 4个(nil指针解引用)
- unused: 1个(未使用字段)

代码统计:
- 修改文件:11个
- 删除代码:490行
- 新增代码:105行
- 净减少:385行
2026-01-02 20:01:12 +08:00

1990 lines
62 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
)
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
)
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
"x-stainless-retry-count": true,
"x-stainless-timeout": true,
"x-stainless-lang": true,
"x-stainless-package-version": true,
"x-stainless-os": true,
"x-stainless-arch": true,
"x-stainless-runtime": true,
"x-stainless-runtime-version": true,
"x-stainless-helper-method": true,
"anthropic-dangerous-direct-browser-access": true,
"anthropic-version": true,
"x-app": true,
"anthropic-beta": true,
"accept-language": true,
"sec-fetch-mode": true,
"user-agent": true,
"content-type": true,
}
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(
accountRepo AccountRepository,
groupRepo GroupRepository,
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
}
}
// GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
}
}
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
}
}
}
return ""
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var builder strings.Builder
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok {
_, _ = builder.WriteString(text)
}
}
}
}
}
}
systemText := builder.String()
// 检查 messages 中的 cacheable 内容
for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
return s.extractTextFromContent(msgMap["content"])
}
}
}
}
}
}
}
return systemText
}
func (s *GatewayService) extractTextFromSystem(system any) string {
switch v := system.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) extractTextFromContent(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if partMap["type"] == "text" {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
}
// replaceModelInBody 替换请求体中的model字段
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
req["model"] = newModel
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
// 2. 获取可调度账号列表(单平台)
var accounts []Account
var err error
if s.cfg.RunMode == config.RunModeSimple {
// 简易模式:忽略 groupID查询所有可用账号
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
}
// 2. 获取可调度账号列表
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
return apiKey, "apikey", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
// 重试相关常量
const (
maxRetries = 10 // 最大重试次数
retryDelay = 3 * time.Second // 重试等待时间
)
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
}
// API Key 账号:未配置的错误码重试
return !account.ShouldHandleErrorCode(statusCode)
}
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
if parsed == nil {
return nil, fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
})
}
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
return nil, err
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
_ = resp.Body.Close()
time.Sleep(retryDelay)
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
break
}
// 不需要重试(成功或不可重试的错误),跳出循环
break
}
defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
StatusCode: 403,
}
}
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
} else {
fingerprint = fp
// 2. 重写metadata.user_id需要指纹中的ClientID和账号的account_uuid
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth账号应用缓存的指纹到请求头覆盖白名单透传的头
if fingerprint != nil {
s.identityService.ApplyFingerprint(req, fingerprint)
}
// 确保必要的headers存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader
}
// 需要添加oauth beta
parts := strings.Split(clientBetaHeader, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1
for i, p := range parts {
if p == claude.BetaClaudeCode {
claudeCodeIdx = i
break
}
}
if claudeCodeIdx >= 0 {
// 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",")
}
// 没有claude-code放在第一位
return claude.BetaOAuth + "," + clientBetaHeader
}
// 客户端没传,根据模型生成
// haiku 模型不需要 claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader
}
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
switch resp.StatusCode {
case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream authentication failed, please contact administrator"
case 403:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream access forbidden, please contact administrator"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded, please retry later"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream service temporarily unavailable"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
// 返回自定义错误响应
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
}
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed after retries",
},
})
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
}
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 透传其他响应头
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &ClaudeUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// 只替换 message_start 事件中的 message.model
if event["type"] != "message_start" {
return line
}
msg, ok := event["message"].(map[string]any)
if !ok {
return line
}
model, ok := msg["model"].(string)
if !ok || model != fromModel {
return line
}
msg["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens标准Claude API格式
var msgStart struct {
Type string `json:"type"`
Message struct {
Usage ClaudeUsage `json:"usage"`
} `json:"message"`
}
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
}
// 解析message_delta获取tokens兼容GLM等把所有usage放在delta中的API
var msgDelta struct {
Type string `json:"type"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取兼容GLM等API
if usage.InputTokens == 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解析usage
var response struct {
Usage ClaudeUsage `json:"usage"`
}
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
// 如果有模型映射替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
return &response.Usage, nil
}
// replaceModelInResponseBody 替换响应体中的model字段
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
// 计算费用
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
// 获取费率倍数
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
}
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
log.Printf("Calculate cost failed: %v", err)
// 使用默认费用继续
cost = &CostBreakdown{ActualCost: 0}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
}
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil || len(accounts) == 0 {
return nil
}
// Filter by platform if specified
if platform != "" {
filtered := make([]Account, 0)
for _, acc := range accounts {
if acc.Platform == platform {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
// Collect unique models from all accounts
modelSet := make(map[string]struct{})
hasAnyMapping := false
for _, acc := range accounts {
mapping := acc.GetModelMapping()
if len(mapping) > 0 {
hasAnyMapping = true
for model := range mapping {
modelSet[model] = struct{}{}
}
}
}
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
return nil
}
// Convert to slice
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
return models
}