Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
This commit is contained in:
Edric Li
2026-01-01 16:15:16 +08:00
215 changed files with 22998 additions and 1641 deletions

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"strings"
"time"
"github.com/spf13/viper"
)
@@ -12,6 +13,20 @@ const (
RunModeSimple = "simple"
)
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
const (
// ConnectionPoolIsolationProxy: 按代理隔离
// 同一代理地址共享连接池,适合代理数量少、账户数量多的场景
ConnectionPoolIsolationProxy = "proxy"
// ConnectionPoolIsolationAccount: 按账户隔离
// 每个账户独立连接池,适合账户数量少、需要严格隔离的场景
ConnectionPoolIsolationAccount = "account"
// ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认)
// 同一账户+代理组合共享连接池,提供最细粒度的隔离
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
@@ -29,6 +44,7 @@ type Config struct {
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
}
type GeminiOAuthConfig struct {
@@ -37,6 +53,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"`
}
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -79,12 +106,71 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns int `mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲0表示无限制
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
// MaxUpstreamClients: 上游连接池客户端最大缓存数量
// 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端
// 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端
// 建议值:预估的活跃账户数 * 1.2(留有余量)
MaxUpstreamClients int `mapstructure:"max_upstream_clients"`
// ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒)
// 超过此时间未使用的客户端会被标记为可回收
// 建议值:根据用户访问频率设置,一般 10-30 分钟
ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
}
func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
@@ -92,6 +178,15 @@ type DatabaseConfig struct {
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns int `mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns int `mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
}
func (d *DatabaseConfig) DSN() string {
@@ -112,11 +207,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
)
}
// RedisConfig Redis 连接配置
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
// 连接池与超时配置(性能优化:可配置化连接池参数)
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
// PoolSize: 连接池大小,控制最大并发连接数
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
}
func (r *RedisConfig) Address() string {
@@ -203,12 +311,21 @@ func setDefaults() {
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
@@ -240,6 +357,26 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数含活跃HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
@@ -254,6 +391,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
}
func (c *Config) Validate() error {
@@ -263,6 +401,86 @@ func (c *Config) Validate() error {
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
}
if c.Database.MaxIdleConns < 0 {
return fmt.Errorf("database.max_idle_conns must be non-negative")
}
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
}
if c.Database.ConnMaxLifetimeMinutes < 0 {
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
}
if c.Database.ConnMaxIdleTimeMinutes < 0 {
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
}
if c.Redis.DialTimeoutSeconds <= 0 {
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
}
if c.Redis.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("redis.read_timeout_seconds must be positive")
}
if c.Redis.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("redis.write_timeout_seconds must be positive")
}
if c.Redis.PoolSize <= 0 {
return fmt.Errorf("redis.pool_size must be positive")
}
if c.Redis.MinIdleConns < 0 {
return fmt.Errorf("redis.min_idle_conns must be non-negative")
}
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
default:
return fmt.Errorf("gateway.connection_pool_isolation must be one of: %s/%s/%s",
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
}
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
if c.Gateway.MaxIdleConnsPerHost <= 0 {
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
}
if c.Gateway.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
}
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.MaxUpstreamClients <= 0 {
return fmt.Errorf("gateway.max_upstream_clients must be positive")
}
if c.Gateway.ClientIdleTTLSeconds <= 0 {
return fmt.Errorf("gateway.client_idle_ttl_seconds must be positive")
}
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil
}

View File

@@ -1,6 +1,11 @@
package config
import "testing"
import (
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) {
tests := []struct {
@@ -21,3 +26,45 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
}
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}

View File

@@ -10,15 +10,17 @@ import (
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
}
}
@@ -108,6 +110,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
// Turnstile 参数验证
if req.TurnstileEnabled {
// 检查必填字段
if req.TurnstileSiteKey == "" {
response.BadRequest(c, "Turnstile Site Key is required when enabled")
return
}
if req.TurnstileSecretKey == "" {
response.BadRequest(c, "Turnstile Secret Key is required when enabled")
return
}
// 获取当前设置,检查参数是否有变化
currentSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 当 site_key 或 secret_key 任一变化时验证(避免配置错误导致无法登录)
siteKeyChanged := currentSettings.TurnstileSiteKey != req.TurnstileSiteKey
secretKeyChanged := currentSettings.TurnstileSecretKey != req.TurnstileSecretKey
if siteKeyChanged || secretKeyChanged {
if err := h.turnstileService.ValidateSecretKey(c.Request.Context(), req.TurnstileSecretKey); err != nil {
response.ErrorFrom(c, err)
return
}
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,

View File

@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -76,15 +80,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
@@ -106,7 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -124,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
platform := ""
@@ -133,6 +141,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini {
const maxAccountSwitches = 3
@@ -141,7 +153,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -150,35 +162,77 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -223,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -232,23 +286,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流
@@ -256,11 +349,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -525,6 +621,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -534,15 +634,18 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -554,17 +657,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return

View File

@@ -3,6 +3,7 @@ package handler
import (
"context"
"fmt"
"math/rand"
"net/http"
"time"
@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin"
)
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔100ms轮询并发槽位存在以下问题
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
@@ -65,6 +83,16 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID)
}
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -108,7 +136,12 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Determine if ping is needed (streaming + ping format defined)
@@ -131,8 +164,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
@@ -156,7 +191,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
flusher.Flush()
case <-pollTicker.C:
case <-timer.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
@@ -174,6 +209,40 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
timer.Reset(backoff)
}
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil此时不添加抖动
// 返回值下一次退避时间100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动jitter 范围 0.8 ~ 1.2
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}

View File

@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
@@ -191,14 +195,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -207,12 +216,48 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
@@ -225,6 +270,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
@@ -76,6 +80,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
// 验证 model 必填
if reqModel == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) {
@@ -136,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 {
@@ -146,14 +156,50 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return
}
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request
@@ -161,6 +207,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {

View File

@@ -0,0 +1,27 @@
package handler
import (
"errors"
"fmt"
"net/http"
)
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
return maxErr, true
}
return nil, false
}
func formatBodyLimit(limit int64) string {
const mb = 1024 * 1024
if limit >= mb {
return fmt.Sprintf("%dMB", limit/mb)
}
return fmt.Sprintf("%dB", limit)
}
func buildBodyTooLargeMessage(limit int64) string {
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}

View File

@@ -0,0 +1,45 @@
package handler
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRequestBodyLimitTooLarge(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := int64(16)
router := gin.New()
router.Use(middleware.RequestBodyLimit(limit))
router.POST("/test", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": buildBodyTooLargeMessage(maxErr.Limit),
})
return
}
c.JSON(http.StatusBadRequest, gin.H{
"error": "read_failed",
})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
payload := bytes.Repeat([]byte("a"), int(limit+1))
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
}

View File

@@ -1,16 +0,0 @@
package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
}

View File

@@ -1,79 +0,0 @@
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -57,6 +57,7 @@ var geminiModels = []string{
"gemini-2.5-flash-lite",
"gemini-3-flash",
"gemini-3-pro-low",
"gemini-3-pro-high",
}
func TestMain(m *testing.M) {
@@ -641,6 +642,37 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
t.Logf("✅ thinking 模式工具调用测试通过, id=%v", result["id"])
}
// TestClaudeMessagesWithGeminiModel 测试在 Claude 端点使用 Gemini 模型
// 验证:通过 /v1/messages 端点传入 gemini 模型名的场景(含前缀映射)
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestClaudeMessagesWithGeminiModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Claude 端点调用 Gemini 模型
geminiViaClaude := []string{
"gemini-3-flash", // 直接支持
"gemini-3-pro-low", // 直接支持
"gemini-3-pro-high", // 直接支持
"gemini-3-pro", // 前缀映射 -> gemini-3-pro-high
"gemini-3-pro-preview", // 前缀映射 -> gemini-3-pro-high
}
for i, model := range geminiViaClaude {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Claude端点", func(t *testing.T) {
testClaudeMessage(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
testClaudeMessage(t, model, true)
})
}
}
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
// 验证Gemini 模型接受没有 signature 的 thinking block
func TestClaudeMessagesWithNoSignature(t *testing.T) {
@@ -738,3 +770,30 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
}
t.Logf("✅ 无 signature thinking 处理测试通过, id=%v", result["id"])
}
// TestGeminiEndpointWithClaudeModel 测试通过 Gemini 端点调用 Claude 模型
// 仅在 Antigravity 模式下运行ENDPOINT_PREFIX="/antigravity"
func TestGeminiEndpointWithClaudeModel(t *testing.T) {
if endpointPrefix != "/antigravity" {
t.Skip("仅在 Antigravity 模式下运行")
}
// 测试通过 Gemini 端点调用 Claude 模型
claudeViaGemini := []string{
"claude-sonnet-4-5",
"claude-opus-4-5-thinking",
}
for i, model := range claudeViaGemini {
if i > 0 {
time.Sleep(testInterval)
}
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
testGeminiGenerate(t, model, false)
})
time.Sleep(testInterval)
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
testGeminiGenerate(t, model, true)
})
}
}

View File

@@ -37,12 +37,26 @@ type ClaudeMetadata struct {
}
// ClaudeTool Claude 工具定义
// 支持两种格式:
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
type ClaudeTool struct {
Name string `json:"name"`
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
Name string `json:"name"`
Description string `json:"description,omitempty"` // 标准格式使用
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
}
// CustomToolSpec MCP custom 工具规格
type CustomToolSpec struct {
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`

View File

@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
"log"
"strings"
"github.com/google/uuid"
@@ -13,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
@@ -30,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq)
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
@@ -147,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
Thought: true,
Text: "Thinking...",
Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...)
}
}
@@ -170,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -198,15 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
part := GeminiPart{
Text: block.Thinking,
Thought: true,
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
parts = append(parts, part)
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image":
if block.Source != nil && block.Source.Type == "base64" {
@@ -231,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID,
},
}
// 保留原有 signature或对 Gemini 模型使用 dummy signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if allowDummyThought {
// 只有 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature避免验证问题
if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -378,13 +433,53 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for _, tool := range tools {
for i, tool := range tools {
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
continue
}
var description string
var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" {
if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
inputSchema = tool.InputSchema
}
// 清理 JSON Schema
params := cleanJSONSchema(tool.InputSchema)
params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
"type": "OBJECT",
"properties": map[string]any{},
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: tool.Description,
Description: description,
Parameters: params,
})
}
@@ -443,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{
"$schema": true,
"$id": true,
"$ref": true,
"additionalProperties": true,
"minLength": true,
"maxLength": true,
"minItems": true,
"maxItems": true,
"uniqueItems": true,
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"pattern": true,
"format": true,
"default": true,
"strict": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// 元 schema 字段
"$schema": true,
"$id": true,
"$ref": true,
// 字符串验证Gemini 不支持)
"minLength": true,
"maxLength": true,
"pattern": true,
// 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"minimum": true,
"maximum": true,
"exclusiveMinimum": true,
"exclusiveMaximum": true,
"multipleOf": true,
// 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schemaGemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
}
// cleanSchemaValue 递归清理 schema 值
@@ -487,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}

View File

@@ -0,0 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",

View File

@@ -0,0 +1,157 @@
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
if client, ok := cached.(*http.Client); ok {
return client, nil
}
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
if c, ok := actual.(*http.Client); ok {
return c, nil
}
return client, nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}

View File

@@ -4,7 +4,7 @@ import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
)

View File

@@ -9,7 +9,7 @@ import (
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
Message: errors2.UnknownMessage,
},
},
}

View File

@@ -14,6 +14,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"time"
@@ -56,7 +57,7 @@ func newAccountRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *accoun
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
if account == nil {
return nil
return service.ErrAccountNilInput
}
builder := r.client.Account.Create().
@@ -98,7 +99,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
created, err := builder.Save(ctx)
if err != nil {
return err
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.ID = created.ID
@@ -231,11 +232,32 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
_, err := r.client.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx)
return err
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
@@ -393,25 +415,49 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
if _, err := r.client.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client
txClient = r.client
}
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(accountID)).Exec(ctx); err != nil {
return err
}
if len(groupIDs) == 0 {
if tx != nil {
return tx.Commit()
}
return nil
}
builders := make([]*dbent.AccountGroupCreate, 0, len(groupIDs))
for i, groupID := range groupIDs {
builders = append(builders, r.client.AccountGroup.Create().
builders = append(builders, txClient.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(i+1),
)
}
_, err := r.client.AccountGroup.CreateBulk(builders...).Save(ctx)
return err
if _, err := txClient.AccountGroup.CreateBulk(builders...).Save(ctx); err != nil {
return err
}
if tx != nil {
return tx.Commit()
}
return nil
}
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
@@ -555,24 +601,30 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
return nil
}
accountExtra, err := r.client.Account.Query().
Where(dbaccount.IDEQ(id)).
Select(dbaccount.FieldExtra).
Only(ctx)
// 使用 JSONB 合并操作实现原子更新,避免读-改-写的并发丢失更新问题
payload, err := json.Marshal(updates)
if err != nil {
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
return err
}
extra := normalizeJSONMap(accountExtra.Extra)
for k, v := range updates {
extra[k] = v
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
)
if err != nil {
return err
}
_, err = r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
SetExtra(extra).
Save(ctx)
return err
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {

View File

@@ -311,19 +311,20 @@ func groupEntityToService(g *dbent.Group) *service.Group {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
DefaultValidityDays: g.DefaultValidityDays,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}

View File

@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Impersonate: true,
})
}
func prefix(s string, n int) string {

View File

@@ -6,9 +6,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)

View File

@@ -2,68 +2,95 @@ package repository
import (
"context"
"errors"
"fmt"
"time"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键concurrency:account:{id}:{requestID}
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合Sorted Set
// 1. 每个账号/用户只有一个键,成员为 requestID分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const (
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
// 并发槽位键前缀(有序集合)
// 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// acquireScript 使用有序集合计数并在未达上限时添加槽位
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
// ARGV[2] = TTL(秒)
// ARGV[3] = requestID
acquireScript = redis.NewScript(`
local pattern = KEYS[1]
local slotKey = KEYS[2]
local key = KEYS[1]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Check if we can acquire a slot
-- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
return 0
`)
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
// getCountScript 统计有序集合中的槽位数量并清理过期条目
// 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
-- 使用 Redis 服务器时间
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
@@ -89,55 +116,138 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
// waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
@@ -145,13 +255,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := accountSlotKey(accountID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
@@ -161,10 +272,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
@@ -172,13 +282,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := userSlotKey(userID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
@@ -189,7 +300,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
if err != nil {
return false, err
}
@@ -201,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

View File

@@ -0,0 +1,135 @@
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}

View File

@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
@@ -212,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
@@ -226,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -0,0 +1,32 @@
package repository
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}

View File

@@ -1,6 +1,6 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure
package repository
import (
"context"
@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if err != nil {
return nil, nil, err
}
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth

View File

@@ -1,15 +1,35 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。

View File

@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
}
func createGeminiReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}

View File

@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
client := req.C().SetTimeout(30 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
}
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {

View File

@@ -9,6 +9,7 @@ import (
"os"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -17,10 +18,14 @@ type githubReleaseClient struct {
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &githubReleaseClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
httpClient: sharedClient,
}
}
@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
if err != nil {
return err
}

View File

@@ -42,7 +42,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD)
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
created, err := builder.Save(ctx)
if err == nil {
@@ -79,6 +80,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
@@ -89,7 +91,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return err
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
@@ -239,8 +241,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分未找到与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 FOR UPDATE", id)
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil {
return nil, err
}
@@ -263,7 +265,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1", id)
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil {
return nil, err
}
@@ -282,7 +285,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
if _, err := exec.ExecContext(ctx, "DELETE FROM user_subscriptions WHERE group_id = $1", id); err != nil {
// 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err
}
}
@@ -297,18 +301,11 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
// 3. Remove the group id from users.allowed_groups array (legacy representation).
// Phase 1 compatibility: also delete from user_allowed_groups join table when present.
// 3. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
if _, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1) WHERE $1 = ANY(allowed_groups)",
id,
); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {

View File

@@ -478,3 +478,58 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}

View File

@@ -1,67 +1,604 @@
package repository
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// httpUpstreamService is a generic HTTP upstream service that can be used for
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
// 默认配置常量
// 这些值在配置文件未指定时作为回退默认值使用
const (
// directProxyKey: 无代理时的缓存键标识
directProxyKey = "direct"
// defaultMaxIdleConns: 默认最大空闲连接总数
// HTTP/2 场景下单连接可多路复用240 足以支撑高并发
defaultMaxIdleConns = 240
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
defaultMaxIdleConnsPerHost = 120
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间5分钟
// 超时后连接会被关闭,释放系统资源
defaultIdleConnTimeout = 300 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间5分钟
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
// 超出后会淘汰最久未使用的客户端
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值15分钟
defaultClientIdleTTLSeconds = 900
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type poolSettings struct {
maxIdleConns int // 最大空闲连接总数
maxIdleConnsPerHost int // 每主机最大空闲连接数
maxConnsPerHost int // 每主机最大连接数(含活跃)
idleConnTimeout time.Duration // 空闲连接超时时间
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
}
// httpUpstreamService 通用 HTTP 上游服务
// 用于向任意 HTTP APIClaude、OpenAI 等)发送请求,支持可选代理
//
// 架构设计:
// - 根据隔离策略proxy/account/account_proxy缓存客户端实例
// - 每个客户端拥有独立的 Transport 连接池
// - 支持 LRU + 空闲时间双重淘汰策略
//
// 性能优化:
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
// 4. 达到最大连接数后等待可用连接,而非无限创建
// 5. 仅回收空闲客户端,避免中断活跃请求
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
// 7. 代理变更时清空旧连接池,避免复用错误代理
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
type httpUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池key 由隔离策略决定
}
// NewHTTPUpstream creates a new generic HTTP upstream service
// NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &httpUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
cfg: cfg,
clients: make(map[string]*upstreamClientEntry),
}
}
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
}
client := s.createProxyClient(proxyURL)
return client.Do(req)
}
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
// Do 执行 HTTP 请求
// 根据隔离策略获取或创建客户端,并跟踪请求生命周期
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID用于账户级隔离
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应Body 已包装,关闭时自动更新计数)
// - error: 请求错误
//
// 注意:
// - 调用方必须关闭 resp.Body否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
return s.defaultClient
return nil, err
}
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE在完全读取前不会被淘汰
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return &http.Client{Transport: transport}
return resp, nil
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
// getOrCreateClient 获取或创建客户端
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
//
// 参数:
// - proxyURL: 代理地址
// - accountID: 账户 ID
// - accountConcurrency: 账户并发限制
//
// 返回:
// - *upstreamClientEntry: 客户端缓存条目
//
// 隔离策略说明:
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
poolKey := s.buildPoolKey(isolation, accountConcurrency)
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径:命中缓存直接返回,减少锁竞争
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径:创建或重建客户端
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
return entry, nil
}
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
// shouldReuseEntry 判断缓存条目是否可复用
// 若代理或连接池配置发生变化,则需要重建客户端
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
if entry == nil {
return false
}
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
return false
}
if entry.poolKey != poolKey {
return false
}
return true
}
// removeClientLocked 移除客户端(需持有锁)
// 从缓存中删除并关闭空闲连接
//
// 参数:
// - key: 缓存键
// - entry: 客户端条目
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
delete(s.clients, key)
if entry != nil && entry.client != nil {
// 关闭空闲连接,释放系统资源
// 注意:这不会中断活跃连接
entry.client.CloseIdleConnections()
}
}
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
//
// 参数:
// - now: 当前时间
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
ttl := s.clientIdleTTL()
if ttl <= 0 {
return
}
// 计算淘汰截止时间
cutoff := now.Add(-ttl).UnixNano()
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
// 淘汰超时的空闲客户端
if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
s.removeClientLocked(key, entry)
}
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return false
}
s.removeClientLocked(oldestKey, oldestEntry)
return true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func (s *httpUpstreamService) evictOverLimitLocked() bool {
maxClients := s.maxUpstreamClients()
if maxClients <= 0 {
return false
}
evicted := false
// 循环淘汰直到满足数量限制
for len(s.clients) > maxClients {
if !s.evictOldestIdleLocked() {
return evicted
}
evicted = true
}
return evicted
}
// getIsolationMode 获取连接池隔离模式
// 从配置中读取,无效值回退到 account_proxy 模式
//
// 返回:
// - string: 隔离模式proxy/account/account_proxy
func (s *httpUpstreamService) getIsolationMode() string {
if s.cfg == nil {
return config.ConnectionPoolIsolationAccountProxy
}
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
if mode == "" {
return config.ConnectionPoolIsolationAccountProxy
}
switch mode {
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
return mode
default:
return config.ConnectionPoolIsolationAccountProxy
}
}
// maxUpstreamClients 获取最大客户端缓存数量
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) maxUpstreamClients() int {
if s.cfg == nil {
return defaultMaxUpstreamClients
}
if s.cfg.Gateway.MaxUpstreamClients > 0 {
return s.cfg.Gateway.MaxUpstreamClients
}
return defaultMaxUpstreamClients
}
// clientIdleTTL 获取客户端空闲回收阈值
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
if s.cfg == nil {
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
}
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
// resolvePoolSettings 解析连接池配置
// 根据隔离策略和账户并发数动态调整连接池参数
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - poolSettings: 连接池配置
//
// 说明:
// - 账户隔离模式下,连接池大小与账户并发数对应
// - 这确保了单账户不会占用过多连接资源
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
settings := defaultPoolSettings(s.cfg)
// 账户隔离模式下,根据账户并发数调整连接池大小
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
settings.maxIdleConns = accountConcurrency
settings.maxIdleConnsPerHost = accountConcurrency
settings.maxConnsPerHost = accountConcurrency
}
return settings
}
// buildPoolKey 构建连接池配置键
// 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
return fmt.Sprintf("account:%d", accountConcurrency)
}
}
return "default"
}
// buildCacheKey 构建客户端缓存键
// 根据隔离策略决定缓存键的组成
//
// 参数:
// - isolation: 隔离模式
// - proxyKey: 代理标识
// - accountID: 账户 ID
//
// 返回:
// - string: 缓存键
//
// 缓存键格式:
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
switch isolation {
case config.ConnectionPoolIsolationAccount:
return fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
return fmt.Sprintf("proxy:%s", proxyKey)
}
}
// normalizeProxyURL 标准化代理 URL
// 处理空值和解析错误,返回标准化的键和解析后的 URL
//
// 参数:
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct"
// - *url.URL: 解析后的 URL空或解析失败返回 nil
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return directProxyKey, nil
}
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
parsed.RawPath = ""
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.ForceQuery = false
if hostname := parsed.Hostname(); hostname != "" {
port := parsed.Port()
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
port = ""
}
hostname = strings.ToLower(hostname)
if port != "" {
parsed.Host = net.JoinHostPort(hostname, port)
} else {
parsed.Host = hostname
}
}
return parsed.String(), parsed
}
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
// 参数:
// - cfg: 全局配置
//
// 返回:
// - poolSettings: 连接池配置
func defaultPoolSettings(cfg *config.Config) poolSettings {
maxIdleConns := defaultMaxIdleConns
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
maxConnsPerHost := defaultMaxConnsPerHost
idleConnTimeout := defaultIdleConnTimeout
responseHeaderTimeout := defaultResponseHeaderTimeout
if cfg != nil {
if cfg.Gateway.MaxIdleConns > 0 {
maxIdleConns = cfg.Gateway.MaxIdleConns
}
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
}
if cfg.Gateway.MaxConnsPerHost >= 0 {
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
}
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
}
if cfg.Gateway.ResponseHeaderTimeout > 0 {
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
}
}
return poolSettings{
maxIdleConns: maxIdleConns,
maxIdleConnsPerHost: maxIdleConnsPerHost,
maxConnsPerHost: maxConnsPerHost,
idleConnTimeout: idleConnTimeout,
responseHeaderTimeout: responseHeaderTimeout,
}
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URLnil 表示直连)
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
if proxyURL != nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
return transport
}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
io.ReadCloser // 原始响应体
once sync.Once
onClose func() // 关闭时的回调函数
}
// Close 关闭响应体并执行回调
// 使用 sync.Once 确保回调只执行一次
func (b *trackedBody) Close() error {
err := b.ReadCloser.Close()
if b.onClose != nil {
b.once.Do(b.onClose)
}
return err
}
// wrapTrackedBody 包装响应体以跟踪关闭事件
// 用于在响应体关闭时更新 inFlight 计数
//
// 参数:
// - body: 原始响应体
// - onClose: 关闭时的回调函数
//
// 返回:
// - io.ReadCloser: 包装后的响应体
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
if body == nil {
return body
}
return &trackedBody{ReadCloser: body, onClose: onClose}
}

View File

@@ -0,0 +1,66 @@
package repository
import (
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确
var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
//
// 测试目的:
// - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异
//
// 预期结果:
// - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 创建测试配置
cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
}
upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService)
if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
}
proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs() // 报告内存分配统计
// 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client
b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL)
if err != nil {
b.Fatalf("解析代理地址失败: %v", err)
}
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
httpClientSink = &http.Client{
Transport: buildUpstreamTransport(settings, parsedProxy),
}
}
})
// 子测试:复用已缓存的客户端
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {
// 直接使用缓存的客户端,无内存分配
httpClientSink = client
}
})
}

View File

@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
@@ -12,45 +13,86 @@ import (
"github.com/stretchr/testify/suite"
)
// HTTPUpstreamSuite HTTP 上游服务测试套件
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
type HTTPUpstreamSuite struct {
suite.Suite
cfg *config.Config
cfg *config.Config // 测试用配置
}
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
}
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
// newService 创建测试用的 httpUpstreamService 实例
// 返回具体类型以便访问内部状态进行断言
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
return svc
}
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
// 验证配置值能正确应用到 Transport
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
got := svc.createProxyClient("://bad-proxy-url")
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
svc := s.newService()
entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 1,
}
svc := s.newService()
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
require.NoError(s.T(), err, "expected first acquire to succeed")
require.NotNil(s.T(), entry1, "expected entry")
entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
require.Error(s.T(), err, "expected error when cache limit reached")
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
}
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
@@ -60,17 +102,21 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied")
}))
s.T().Cleanup(proxySrv.Close)
@@ -78,14 +124,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg)
// 发送请求到外部地址,应通过代理
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL)
resp, err := up.Do(req, proxySrv.URL, 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body")
// 验证代理收到的是绝对 URI 格式HTTP 代理规范要求)
select {
case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
@@ -94,6 +142,8 @@ func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
}
}
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
@@ -103,13 +153,134 @@ func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b))
}
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
// 验证不同账户使用独立的连接池
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
// 验证同一账户使用不同代理时创建独立连接池
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
// 验证账户切换代理时清理旧连接池,避免复用错误代理
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
}
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
// 验证账户隔离模式下,连接池大小与账户并发数对应
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
}
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
// 验证未指定并发数时使用全局配置值
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
MaxIdleConns: 77,
MaxIdleConnsPerHost: 55,
MaxConnsPerHost: 66,
}
svc := s.newService()
// 账户并发数为 0应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
}
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
// 验证优先淘汰最久未使用的空闲客户端
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 2, // 最多缓存 2 个客户端
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
}
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
// 验证有进行中请求的客户端不会被空闲超时淘汰
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
for _, entry := range svc.clients {
if entry == target {
return true
}
}
return false
}

View File

@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err)
os.Exit(1)
}
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err)
os.Exit(1)
}
@@ -330,7 +329,8 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {

View File

@@ -1,4 +1,4 @@
package infrastructure
package repository
import (
"context"
@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}

View File

@@ -7,7 +7,6 @@ import (
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
)
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set.
var applied int
@@ -53,6 +52,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {

View File

@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createOpenAIReqClient(proxyURL string) *req.Client {
client := req.C().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
}
func NewPricingRemoteClient() service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
httpClient: sharedClient,
}
}

View File

@@ -2,18 +2,14 @@ package repository
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
"golang.org/x/net/proxy"
)
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
@@ -27,14 +23,14 @@ type proxyProbeService struct {
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL)
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
InsecureSkipVerify: true,
ProxyStrict: true,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
startTime := time.Now()
@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
Country: ipInfo.Country,
}, latencyMs, nil
}
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
_, err := createProxyTransport("://bad")
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL")
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1")
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {

View File

@@ -178,7 +178,7 @@ func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id")
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}

View File

@@ -168,7 +168,8 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
affected, err := r.client.RedeemCode.Update().
client := clientFromContext(ctx, r.client)
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).

View File

@@ -0,0 +1,39 @@
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"fmt"
"strings"
"sync"
"time"
"github.com/imroc/req/v3"
)
// reqClientOptions 定义 req 客户端的构建参数
type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
//
// 性能优化说明:
// 原实现在每次 OAuth 刷新时都创建新的 req.Client
// 1. claude_oauth_service.go: 每次刷新创建新客户端
// 2. openai_oauth_service.go: 每次刷新创建新客户端
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
//
// 新实现使用 sync.Map 缓存客户端:
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
// 2. 复用底层连接池,减少 TLS 握手开销
// 3. LoadOrStore 保证并发安全,避免重复创建
var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
}
}
client := req.C().SetTimeout(opts.Timeout)
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
}
return client
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
)
}

View File

@@ -105,3 +105,59 @@ func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(err)
s.Require().Equal("new_val", got2)
}
// TestSet_EmptyValue 测试保存空字符串值
// 这是一个回归测试确保可选设置如站点Logo、API端点地址等可以保存为空字符串
func (s *SettingRepoSuite) TestSet_EmptyValue() {
// 测试 Set 方法保存空值
s.Require().NoError(s.repo.Set(s.ctx, "empty_key", ""), "Set with empty value should succeed")
got, err := s.repo.GetValue(s.ctx, "empty_key")
s.Require().NoError(err, "GetValue for empty value")
s.Require().Equal("", got, "empty value should be preserved")
}
// TestSetMultiple_WithEmptyValues 测试批量保存包含空字符串的设置
// 模拟用户保存站点设置时部分字段为空的场景
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// 模拟保存站点设置,部分字段有值,部分字段为空
settings := map[string]string{
"site_name": "AICodex2API",
"site_subtitle": "Subscription to API",
"site_logo": "", // 用户未上传Logo
"api_base_url": "", // 用户未设置API地址
"contact_info": "", // 用户未设置联系方式
"doc_url": "", // 用户未设置文档链接
}
s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed")
// 验证所有值都正确保存
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("AICodex2API", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
s.Require().Equal("", result["contact_info"], "empty contact_info should be preserved")
s.Require().Equal("", result["doc_url"], "empty doc_url should be preserved")
}
// TestSetMultiple_UpdateToEmpty 测试将已有值更新为空字符串
// 确保用户可以清空之前设置的值
func (s *SettingRepoSuite) TestSetMultiple_UpdateToEmpty() {
// 先设置非空值
s.Require().NoError(s.repo.Set(s.ctx, "clearable_key", "initial_value"))
got, err := s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("initial_value", got)
// 更新为空值
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"clearable_key": ""}), "Update to empty should succeed")
got, err = s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("", got, "value should be updated to empty string")
}

View File

@@ -7,10 +7,12 @@ import (
"fmt"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -111,3 +113,104 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}
// --- UserSubscription 软删除测试 ---
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
t.Helper()
g, err := client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err, "create ent group")
return g
}
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")
_, err := repo.GetByID(ctx, sub.ID)
require.Error(t, err, "deleted rows should be hidden by default")
_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(sub.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))
repo := NewUserSubscriptionRepository(client)
sub1 := &service.UserSubscription{
UserID: u.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")
sub2 := &service.UserSubscription{
UserID: u.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")
// 软删除 sub1
require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")
// ListByUserID 应只返回未删除的订阅
subs, err := repo.ListByUserID(ctx, u.ID)
require.NoError(t, err, "ListByUserID")
require.Len(t, subs, 1, "should only return non-deleted subscriptions")
require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -20,11 +21,15 @@ type turnstileVerifier struct {
}
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}
}
return &turnstileVerifier{
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifyURL: turnstileVerifyURL,
httpClient: sharedClient,
verifyURL: turnstileVerifyURL,
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"time"
@@ -452,6 +453,176 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
return &stats, nil
}
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
//
// 性能优化说明:
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
// 1. 需要传输大量数据到应用层
// 2. 应用层循环计算增加 CPU 和内存开销
//
// 新实现使用 SQL 聚合函数:
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
// 2. 只返回单行聚合结果,大幅减少数据传输量
// 3. 利用数据库索引优化聚合查询性能
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{accountID, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
// 性能优化:数据库层聚合计算,避免应用层循环统计
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
query := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE model = $1 AND created_at >= $2 AND created_at < $3
`
var stats usagestats.UsageStats
if err := scanSingleRow(
ctx,
r.sql,
query,
[]any{modelName, startTime, endTime},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return &stats, nil
}
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
tzName := resolveUsageStatsTimezone()
query := `
SELECT
-- 使用应用时区分组,避免数据库会话时区导致日边界偏移。
TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date,
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY 1
ORDER BY 1
`
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
result = nil
}
}()
result = make([]map[string]any, 0)
for rows.Next() {
var (
date string
totalRequests int64
totalInputTokens int64
totalOutputTokens int64
totalCacheTokens int64
totalCost float64
totalActualCost float64
avgDurationMs float64
)
if err = rows.Scan(
&date,
&totalRequests,
&totalInputTokens,
&totalOutputTokens,
&totalCacheTokens,
&totalCost,
&totalActualCost,
&avgDurationMs,
); err != nil {
return nil, err
}
result = append(result, map[string]any{
"date": date,
"total_requests": totalRequests,
"total_input_tokens": totalInputTokens,
"total_output_tokens": totalOutputTokens,
"total_cache_tokens": totalCacheTokens,
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
"total_cost": totalCost,
"total_actual_cost": totalActualCost,
"average_duration_ms": avgDurationMs,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。
// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。
func resolveUsageStatsTimezone() string {
tzName := timezone.Name()
if tzName != "" && tzName != "Local" {
return tzName
}
if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" {
return envTZ
}
return "UTC"
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
@@ -938,6 +1109,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
@@ -964,6 +1138,9 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
@@ -1006,6 +1183,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
today := timezone.Today()
todayQuery := `
@@ -1032,6 +1212,9 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}

View File

@@ -12,20 +12,18 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -86,10 +84,11 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err == nil {
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -102,10 +101,11 @@ func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -240,11 +240,12 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err == nil {
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
if err != nil {
return nil, nil, err
}
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
@@ -252,12 +253,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
n, err := r.client.User.Update().
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id), dbuser.BalanceGTE(amount)).
AddBalance(-amount).
Save(ctx)
@@ -271,8 +280,15 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
_, err := r.client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
return err
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
@@ -280,33 +296,14 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
exec := r.sql
if exec == nil {
// 未注入 sqlExecutor 时,退回到 ent client 的 ExecContext支持事务
exec = r.client
}
joinAffected, err := r.client.UserAllowedGroup.Delete().
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx)
if err != nil {
return 0, err
}
arrayRes, err := exec.ExecContext(
ctx,
"UPDATE users SET allowed_groups = array_remove(allowed_groups, $1), updated_at = NOW() WHERE $1 = ANY(allowed_groups)",
groupID,
)
if err != nil {
return 0, err
}
arrayAffected, _ := arrayRes.RowsAffected()
if int64(joinAffected) > arrayAffected {
return int64(joinAffected), nil
}
return arrayAffected, nil
return int64(affected), nil
}
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
@@ -323,10 +320,11 @@ func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, erro
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err == nil {
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
@@ -356,8 +354,7 @@ func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64)
}
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 1) 以 user_allowed_groups 为读写源,确保新旧逻辑一致;
// 2) 额外更新 users.allowed_groups历史字段以保持兼容。
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
@@ -376,12 +373,10 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
unique[id] = struct{}{}
}
legacyGroups := make([]int64, 0, len(unique))
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
legacyGroups = append(legacyGroups, groupID)
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
@@ -392,16 +387,6 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
}
}
// Phase 1 兼容:保持 users.allowed_groups数组字段同步避免旧查询路径读取到过期数据。
var legacy any
if len(legacyGroups) > 0 {
sort.Slice(legacyGroups, func(i, j int) bool { return legacyGroups[i] < legacyGroups[j] })
legacy = pq.Array(legacyGroups)
}
if _, err := client.ExecContext(ctx, "UPDATE users SET allowed_groups = $1::bigint[] WHERE id = $2", legacy, userID); err != nil {
return err
}
return nil
}

View File

@@ -507,3 +507,24 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrInsufficientBalance 因为 WHERE 条件不匹配
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}

View File

@@ -20,10 +20,11 @@ func NewUserSubscriptionRepository(client *dbent.Client) service.UserSubscriptio
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.Create().
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.Create().
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetExpiresAt(sub.ExpiresAt).
@@ -57,7 +58,8 @@ func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.Us
}
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(id)).
WithUser().
WithGroup().
@@ -70,7 +72,8 @@ func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*se
}
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
WithGroup().
Only(ctx)
@@ -81,7 +84,8 @@ func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context,
}
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
m, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
m, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.GroupIDEQ(groupID),
@@ -98,10 +102,11 @@ func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
if sub == nil {
return nil
return service.ErrSubscriptionNilInput
}
builder := r.client.UserSubscription.UpdateOneID(sub.ID).
client := clientFromContext(ctx, r.client)
builder := client.UserSubscription.UpdateOneID(sub.ID).
SetUserID(sub.UserID).
SetGroupID(sub.GroupID).
SetStartsAt(sub.StartsAt).
@@ -127,12 +132,14 @@ func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.Us
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
// Match GORM semantics: deleting a missing row is not an error.
_, err := r.client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.Delete().Where(usersubscription.IDEQ(id)).Exec(ctx)
return err
}
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID)).
WithGroup().
Order(dbent.Desc(usersubscription.FieldCreatedAt)).
@@ -144,7 +151,8 @@ func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.UserIDEQ(userID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -160,7 +168,8 @@ func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID))
total, err := q.Clone().Count(ctx)
if err != nil {
@@ -182,7 +191,8 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
q := r.client.UserSubscription.Query()
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
q = q.Where(usersubscription.UserIDEQ(*userID))
}
@@ -214,34 +224,39 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
return client.UserSubscription.Query().
Where(usersubscription.UserIDEQ(userID), usersubscription.GroupIDEQ(groupID)).
Exist(ctx)
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetExpiresAt(newExpiresAt).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetStatus(status).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
_, err := r.client.UserSubscription.UpdateOneID(subscriptionID).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(subscriptionID).
SetNotes(notes).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyWindowStart(start).
SetWeeklyWindowStart(start).
SetMonthlyWindowStart(start).
@@ -250,7 +265,8 @@ func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetDailyUsageUsd(0).
SetDailyWindowStart(newWindowStart).
Save(ctx)
@@ -258,7 +274,8 @@ func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetWeeklyUsageUsd(0).
SetWeeklyWindowStart(newWindowStart).
Save(ctx)
@@ -266,24 +283,54 @@ func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
client := clientFromContext(ctx, r.client)
_, err := client.UserSubscription.UpdateOneID(id).
SetMonthlyUsageUsd(0).
SetMonthlyWindowStart(newWindowStart).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage 原子性地累加订阅用量。
// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
// 此处仅负责记录实际消费,确保消费数据的完整性。
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
_, err := r.client.UserSubscription.UpdateOneID(id).
AddDailyUsageUsd(costUSD).
AddWeeklyUsageUsd(costUSD).
AddMonthlyUsageUsd(costUSD).
Save(ctx)
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
weekly_usage_usd = us.weekly_usage_usd + $1,
monthly_usage_usd = us.monthly_usage_usd + $1,
updated_at = NOW()
FROM groups g
WHERE us.id = $2
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
return nil
}
// affected == 0订阅不存在或已删除
return service.ErrSubscriptionNotFound
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
n, err := r.client.UserSubscription.Update().
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Update().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -296,7 +343,8 @@ func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
subs, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
subs, err := client.UserSubscription.Query().
Where(
usersubscription.StatusEQ(service.SubscriptionStatusActive),
usersubscription.ExpiresAtLTE(time.Now()),
@@ -309,12 +357,14 @@ func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service
}
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().Where(usersubscription.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.client.UserSubscription.Query().
client := clientFromContext(ctx, r.client)
count, err := client.UserSubscription.Query().
Where(
usersubscription.GroupIDEQ(groupID),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
@@ -325,7 +375,8 @@ func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
client := clientFromContext(ctx, r.client)
n, err := client.UserSubscription.Delete().Where(usersubscription.GroupIDEQ(groupID)).Exec(ctx)
return int64(n), err
}

View File

@@ -4,6 +4,7 @@ package repository
import (
"context"
"fmt"
"testing"
"time"
@@ -631,3 +632,116 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
// --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
group := s.mustCreateGroup("g-softdeleted")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 软删除分组
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
s.Require().NoError(err, "soft delete group")
// IncrementUsage 应该失败,因为分组已软删除
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
s.Require().Error(err, "should fail for soft-deleted group")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
s.Require().Error(err, "should fail for non-existent subscription")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
// --- nil 入参测试 ---
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
err := s.repo.Create(s.ctx, nil)
s.Require().Error(err, "Create should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
err := s.repo.Update(s.ctx, nil)
s.Require().Error(err, "Update should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
// --- 并发用量更新测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
const incrementPerGoroutine = 1.5
// 启动多个 goroutine 并发调用 IncrementUsage
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
}()
}
// 等待所有 goroutine 完成
for i := 0; i < numGoroutines; i++ {
err := <-errCh
s.Require().NoError(err, "IncrementUsage should succeed")
}
// 验证累加结果正确
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())
s.Require().NoError(err, "begin tx")
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
txCtx := dbent.NewTxContext(context.Background(), tx)
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
userEnt, err := tx.Client().User.Create().
SetEmail("tx-user-" + suffix + "@example.com").
SetPasswordHash("test").
Save(txCtx)
s.Require().NoError(err, "create user in tx")
groupEnt, err := tx.Client().Group.Create().
SetName("tx-group-" + suffix).
Save(txCtx)
s.Require().NoError(err, "create group in tx")
repo := NewUserSubscriptionRepository(baseClient)
sub := &service.UserSubscription{
UserID: userEnt.ID,
GroupID: groupEnt.ID,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Status: service.SubscriptionStatusActive,
AssignedAt: time.Now(),
Notes: "tx",
}
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
s.Require().NoError(tx.Rollback(), "rollback tx")
tx = nil
_, err = repo.GetByID(context.Background(), sub.ID)
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}

View File

@@ -1,9 +1,30 @@
package repository
import (
"database/sql"
"errors"
entsql "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -20,7 +41,7 @@ var ProviderSet = wire.NewSet(
NewGatewayCache,
NewBillingCache,
NewApiKeyCache,
NewConcurrencyCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
@@ -38,4 +59,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient,
NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -385,7 +385,7 @@ func newContractDeps(t *testing.T) *contractDeps {
authHandler := handler.NewAuthHandler(cfg, nil, userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -981,6 +981,18 @@ func (r *stubUsageLogRepo) GetApiKeyStatsAggregated(ctx context.Context, apiKeyI
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}

View File

@@ -7,7 +7,7 @@ import (
"os"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)

View File

@@ -8,7 +8,7 @@ import (
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"

View File

@@ -0,0 +1,15 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
c.Next()
}
}

View File

@@ -18,8 +18,11 @@ func RegisterGatewayRoutes(
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
{
gateway.POST("/messages", h.Gateway.Messages)
@@ -32,6 +35,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
@@ -41,10 +45,11 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{
@@ -55,6 +60,7 @@ func RegisterGatewayRoutes(
}
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{

View File

@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string {
}
}
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
// 兼容以下格式:
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
// - Unix 时间戳字符串: "1735689600"
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
func (a *Account) GetCredentialAsTime(key string) *time.Time {
s := a.GetCredential(key)
if s == "" {
return nil
}
// 尝试 RFC3339 格式
if t, err := time.Parse(time.RFC3339, s); err == nil {
return &t
}
// 尝试 Unix 时间戳(纯数字字符串)
if ts, err := strconv.ParseInt(s, 10, 64); err == nil {
t := time.Unix(ts, 0)
return &t
}
return nil
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
if v, ok := a.Credentials["expires_at"].(float64); ok {
tt := time.Unix(int64(v), 0)
return &tt
}
return nil
}
return &t
return a.GetCredentialAsTime("expires_at")
}
func (a *Account) IsOpenAITokenExpired() bool {

View File

@@ -5,12 +5,13 @@ import (
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
type AccountRepository interface {

View File

@@ -12,7 +12,6 @@ import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Check if token needs refresh
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt {
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
if time.Now().Add(5 * time.Minute).After(*expiresAt) {
needRefresh = true
}
}
@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}

View File

@@ -52,6 +52,9 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据utilization, resets_at
@@ -90,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -119,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
}
}
@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return info
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}

View File

@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// 限额字段0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
MonthlyLimitUSD: input.MonthlyLimitUSD,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
// normalizeLimit 将 0 或负数转换为 nil表示无限制
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
return nil
}
return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段支持设置为nil清除限额或具体值
// 限额字段0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
if err := s.groupRepo.Update(ctx, group); err != nil {

View File

@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "gemini-3-flash",
"claude-haiku-4-5": "gemini-3-flash",
"claude-3-haiku-20240307": "gemini-3-flash",
"claude-haiku-4-5-20251001": "gemini-3-flash",
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// 长前缀优先
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法
// 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
// 2. 直接支持的模型透传
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview
for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
}
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// TestConnectionResult 测试连接结果
@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)

View File

@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传

View File

@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
// isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := parseAntigravityExpiresAt(account)
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}

View File

@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
}
// 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account)
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseAntigravityExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = parseAntigravityExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string {
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}

View File

@@ -2,7 +2,7 @@ package service
import (
"context"
"strconv"
"fmt"
"time"
)
@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
}
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < antigravityRefreshWindow
return needsRefresh
}
// Refresh 执行 token 刷新

View File

@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)

View File

@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"

View File

@@ -4,10 +4,12 @@ import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// 错误定义
@@ -27,6 +29,46 @@ type subscriptionCacheData struct {
Version int64
}
// 缓存写入任务类型
type cacheWriteKind int
const (
cacheWriteSetBalance cacheWriteKind = iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine避免频繁创建销毁
// 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
)
// cacheWriteTask 缓存写入任务
type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
@@ -34,16 +76,151 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
return &BillingCacheService{
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
cfg: cfg,
}
svc.startCacheWriteWorkers()
return svc
}
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
select {
case s.cacheWriteChan <- task:
return true
default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
}
}
func (s *BillingCacheService) cacheWriteWorker() {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
s.setBalanceCache(ctx, task.userID, task.balance)
case cacheWriteSetSubscription:
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
}
cancel()
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func cacheWriteKindName(kind cacheWriteKind) string {
switch kind {
case cacheWriteSetBalance:
return "set_balance"
case cacheWriteSetSubscription:
return "set_subscription"
case cacheWriteUpdateSubscriptionUsage:
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
default:
return "unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
if now-last < int64(cacheWriteDropLogInterval) {
return
}
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
}
// ============================================
@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setBalanceCache(cacheCtx, userID, balance)
}()
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
}
@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
}
}
// DeductBalanceCache 扣减余额缓存(步调用,用于扣费后更新缓存
// DeductBalanceCache 扣减余额缓存(步调用)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.cache == nil {
return nil
@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
return s.cache.DeductUserBalance(ctx, userID, amount)
}
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.cache == nil {
@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
}()
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetSubscription,
userID: userID,
groupID: groupID,
subscriptionData: data,
})
return data, nil
}
@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用,用于扣费后更新缓存
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.cache == nil {
return nil
@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
}
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.cache == nil {

View File

@@ -0,0 +1,75 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheWorkerStub struct {
balanceUpdates int64
subscriptionUpdates int64
}
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
for i := 0; i < cacheWriteBufferSize*2; i++ {
svc.QueueDeductBalance(1, 1)
}
require.Less(t, time.Since(start), 2*time.Second)
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.balanceUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}

View File

@@ -9,24 +9,35 @@ import (
"time"
)
// ConcurrencyCache defines cache operations for concurrency service
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
// ConcurrencyCache 定义并发控制的缓存接口
// 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL
// Key format: concurrency:account:{accountID}:{requestID}
// 账号槽位管理
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL
// Key format: concurrency:user:{userID}:{requestID}
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer)
}
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {

View File

@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type CRSSyncService struct {
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return nil, errors.New("username and password are required")
}
client := &http.Client{Timeout: 20 * time.Second}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {

View File

@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)

View File

@@ -10,7 +10,7 @@ import (
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (

View File

@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
}
}
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}

View File

@@ -0,0 +1,72 @@
package service
import (
"encoding/json"
"fmt"
)
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
// 原实现在多个位置重复解析请求体Handler、Service 各解析一次):
// 1. gateway_handler.go 解析获取 model 和 stream
// 2. gateway_service.go 再次解析获取 system、messages、metadata
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
//
// 新实现一次解析,多处复用:
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id用于会话亲和
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
return parsed, nil
}

View File

@@ -0,0 +1,40 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseGatewayRequest(t *testing.T) {
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
require.True(t, parsed.Stream)
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
require.True(t, parsed.HasSystem)
require.NotNil(t, parsed.System)
require.Len(t, parsed.Messages, 1)
}
func TestParseGatewayRequest_SystemNull(t *testing.T) {
body := []byte(`{"model":"claude-3","system":null}`)
parsed, err := ParseGatewayRequest(body)
require.NoError(t, err)
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
require.True(t, parsed.HasSystem)
require.Nil(t, parsed.System)
}
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
body := []byte(`{"model":123}`)
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
body := []byte(`{"stream":"true"}`)
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}

View File

@@ -13,6 +13,7 @@ import (
"log"
"net/http"
"regexp"
"sort"
"strings"
"time"
@@ -33,7 +34,10 @@ const (
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`)
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
@@ -64,6 +68,20 @@ type GatewayCache interface {
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
@@ -106,6 +124,7 @@ type GatewayService struct {
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
}
// NewGatewayService creates a new GatewayService
@@ -117,6 +136,7 @@ func NewGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
@@ -132,6 +152,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
@@ -141,40 +162,36 @@ func NewGatewayService(
}
}
// GenerateSessionHash 从请求计算粘性会话hash
func (s *GatewayService) GenerateSessionHash(body []byte) string {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
// GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 1. 最高优先级从metadata.user_id提取session_xxx
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
if match := re.FindStringSubmatch(userID); len(match) > 1 {
return match[1]
}
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
}
}
// 2. 提取带cache_control: {type: "ephemeral"}的内容
cacheableContent := s.extractCacheableContent(req)
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用system内容
if system := req["system"]; system != nil {
systemText := s.extractTextFromSystem(system)
// 3. Fallback: 使用 system 内容
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
}
}
// 4. 最后fallback: 使用第一条消息
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
if firstMsg, ok := messages[0].(map[string]any); ok {
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
@@ -185,36 +202,46 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
return ""
}
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
var content string
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
// 检查system中的cacheable内容
if system, ok := req["system"].([]any); ok {
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var builder strings.Builder
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok {
content += text
_, _ = builder.WriteString(text)
}
}
}
}
}
}
systemText := builder.String()
// 检查messages中的cacheable内容
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
// 找到cacheable内容提取第一条消息的文本
return s.extractTextFromContent(msgMap["content"])
}
// 检查 messages 中的 cacheable 内容
for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
return s.extractTextFromContent(msgMap["content"])
}
}
}
@@ -223,7 +250,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
}
}
return content
return systemText
}
func (s *GatewayService) extractTextFromSystem(system any) string {
@@ -332,8 +359,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -389,7 +762,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -419,6 +794,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
@@ -478,7 +854,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -515,24 +893,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
@@ -588,19 +952,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
// 解析请求获取model和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
if parsed == nil {
return nil, fmt.Errorf("parse request: empty request")
}
if !gjson.GetBytes(body, "system").Exists() {
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
@@ -613,13 +975,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 应用模型映射仅对apikey类型账号
originalModel := req.Model
originalModel := reqModel
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
req.Model = mappedModel
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
}
}
@@ -640,13 +1002,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
return nil, err
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
@@ -686,14 +1048,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
if req.Stream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -705,7 +1091,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
@@ -715,13 +1101,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: req.Stream,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
@@ -787,7 +1173,14 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
@@ -795,7 +1188,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
@@ -832,15 +1225,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
}
// 客户端没传,根据模型生成
var modelID string
var reqMap map[string]any
if json.Unmarshal(body, &reqMap) == nil {
if m, ok := reqMap["model"].(string); ok {
modelID = m
}
}
// haiku模型不需要claude-code beta
// haiku 模型不需要 claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader
}
@@ -848,6 +1233,83 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
@@ -860,6 +1322,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode {
case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
@@ -1248,13 +1720,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
log.Printf("Update subscription cache failed: %v", err)
}
}()
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
@@ -1263,13 +1729,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
log.Printf("Update balance cache failed: %v", err)
}
}()
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
@@ -1281,7 +1741,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity {
@@ -1291,14 +1759,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
}
}
}
@@ -1311,7 +1777,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
@@ -1324,7 +1790,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
@@ -1345,6 +1811,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
@@ -1363,7 +1841,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
@@ -1424,7 +1902,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil

View File

@@ -0,0 +1,50 @@
package service
import (
"strconv"
"testing"
)
var benchmarkStringSink string
// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
svc := &GatewayService{}
body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
parsed, err := ParseGatewayRequest(body)
if err != nil {
b.Fatalf("解析请求失败: %v", err)
}
benchmarkStringSink = svc.GenerateSessionHash(parsed)
}
}
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
func BenchmarkExtractCacheableContent_System(b *testing.B) {
svc := &GatewayService{}
req := buildSystemCacheableRequest(12)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkStringSink = svc.extractCacheableContent(req)
}
}
func buildSystemCacheableRequest(parts int) *ParsedRequest {
systemParts := make([]any, 0, parts)
for i := 0; i < parts; i++ {
systemParts = append(systemParts, map[string]any{
"text": "system_part_" + strconv.Itoa(i),
"cache_control": map[string]any{
"type": "ephemeral",
},
})
}
return &ParsedRequest{
System: systemParts,
HasSystem: true,
}
}

View File

@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
valid = true
}
if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
@@ -472,7 +493,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
@@ -725,7 +746,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < geminiMaxRetries {
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
@@ -921,7 +942,10 @@ func sleepGeminiBackoff(attempt int) {
time.Sleep(sleepFor)
}
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
var (
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
)
func sanitizeUpstreamErrorMessage(msg string) string {
if msg == "" {
@@ -1753,7 +1777,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
@@ -1883,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
ra := time.Now().Add(5 * time.Minute)
// 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
@@ -1925,7 +1980,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
}
// Match "Please retry in Xs"
retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
matches := retryInRegex.FindStringSubmatch(string(body))
if len(matches) == 2 {
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
@@ -1946,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
}
func nextGeminiDailyResetUnix() *int64 {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
// Fallback: PST without DST.
loc = time.FixedZone("PST", -8*3600)
}
now := time.Now().In(loc)
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
if !reset.After(now) {
reset = reset.Add(24 * time.Hour)
}
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}
@@ -2243,16 +2288,46 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
if !ok {
continue
}
name, _ := tm["name"].(string)
desc, _ := tm["description"].(string)
params := tm["input_schema"]
var name, desc string
var params any
// 检查是否为 custom 类型工具 (MCP)
toolType, _ := tm["type"].(string)
if toolType == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema
custom, ok := tm["custom"].(map[string]any)
if !ok {
continue
}
name, _ = tm["name"].(string)
desc, _ = custom["description"].(string)
params = custom["input_schema"]
} else {
// 标准格式: 从顶层字段获取
name, _ = tm["name"].(string)
desc, _ = tm["description"].(string)
params = tm["input_schema"]
}
if name == "" {
continue
}
// 为 nil params 提供默认值
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{
"name": name,
"description": desc,
"parameters": params,
"parameters": cleanedParams,
})
}
@@ -2266,6 +2341,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
}
}
// cleanToolSchema 清理工具的 JSON Schema移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {

View File

@@ -0,0 +1,128 @@
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
name string
tools any
expectedLen int
description string
}{
{
name: "Standard tools",
tools: []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []any{
map[string]any{
"type": "custom",
"name": "mcp_tool",
"custom": map[string]any{
"description": "MCP tool description",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从custom字段读取",
},
{
name: "Mixed standard and custom tools",
tools: []any{
map[string]any{
"name": "standard_tool",
"description": "Standard",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "custom",
"name": "custom_tool",
"custom": map[string]any{
"description": "Custom",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "混合工具应该都能正确转换",
},
{
name: "Custom tool without custom field",
tools: []any{
map[string]any{
"type": "custom",
"name": "invalid_custom",
// 缺少 custom 字段
},
},
expectedLen: 0, // 应该被跳过
description: "缺少custom字段的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertClaudeToolsToGeminiTools(tt.tools)
if tt.expectedLen == 0 {
if result != nil {
t.Errorf("%s: expected nil result, got %v", tt.description, result)
}
return
}
if result == nil {
t.Fatalf("%s: expected non-nil result", tt.description)
}
if len(result) != 1 {
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
return
}
toolDecl, ok := result[0].(map[string]any)
if !ok {
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
}
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
if !ok {
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
}
toolsArr, _ := tt.tools.([]any)
expectedFuncCount := 0
for _, tool := range toolsArr {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] != "" {
// 检查是否为有效的custom工具
if toolMap["type"] == "custom" {
if toolMap["custom"] != nil {
expectedFuncCount++
}
} else {
expectedFuncCount++
}
}
}
if len(funcDecls) != expectedFuncCount {
t.Errorf("%s: expected %d function declarations, got %d",
tt.description, expectedFuncCount, len(funcDecls))
}
})
}
}

View File

@@ -7,13 +7,14 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type GeminiOAuthService struct {
@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
}
// validateTierID validates tier_id format and length
func validateTierID(tierID string) error {
if tierID == "" {
return nil // Empty is allowed
}
if len(tierID) > 64 {
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
return fmt.Errorf("tier_id contains invalid characters")
}
return nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
tierID := "LEGACY"
// First pass: look for default tier
for _, tier := range allowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if tierID == "LEGACY" {
for _, tier := range allowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
return tierID
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
@@ -219,25 +259,45 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式project_id 是必需的
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" {
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
} else {
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
} else {
tierID = fetchedTierID
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
if tierID == "" {
tierID = "LEGACY"
}
}
return &GeminiTokenInfo{
@@ -248,6 +308,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
@@ -266,8 +327,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -354,18 +422,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
tokenInfo.ProjectID = existingProjectID
}
// 尝试从账号凭证获取 tierID向后兼容
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
if oauthType == "code_assist" {
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
} else {
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
}
}
}
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
@@ -388,6 +477,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
@@ -398,33 +494,22 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil {
return "", errors.New("code assist client not configured")
return "", "", errors.New("code assist client not configured")
}
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
// If LoadCodeAssist returned a project, use it
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
req := &geminicli.OnboardUserRequest{
@@ -443,39 +528,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", err
return "", tierID, err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
return strings.TrimSpace(v), tierID, nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
return strings.TrimSpace(id), tierID, nil
}
}
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", errors.New("onboardUser completed but no project_id returned")
return "", tierID, errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {
@@ -497,11 +582,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
client := &http.Client{Timeout: 30 * time.Second}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
}
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: strings.TrimSpace(proxyURL),
Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Do(req)

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type geminiModelClass string
const (
geminiModelPro geminiModelClass = "pro"
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Cooldown time.Duration
}
type GeminiQuotaPolicy struct {
tiers map[string]GeminiTierPolicy
}
type GeminiUsageTotals struct {
ProRequests int64
FlashRequests int64
ProTokens int64
FlashTokens int64
ProCost float64
FlashCost float64
}
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
mu sync.Mutex
cachedAt time.Time
policy *GeminiQuotaPolicy
}
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
return &GeminiQuotaService{
cfg: cfg,
settingRepo: settingRepo,
}
}
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s == nil {
return newGeminiQuotaPolicy()
}
now := time.Now()
s.mu.Lock()
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
policy := s.policy
s.mu.Unlock()
return policy
}
s.mu.Unlock()
policy := newGeminiQuotaPolicy()
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
if s.settingRepo != nil {
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
s.mu.Lock()
s.policy = policy
s.cachedAt = now
s.mu.Unlock()
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
policy := s.Policy(ctx)
return policy.CooldownForTier(tierID)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
},
}
}
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
if p == nil || len(tiers) == 0 {
return
}
for rawID, override := range tiers {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
policy.Cooldown = time.Duration(minutes) * time.Minute
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
}
return policy.Quota, true
}
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
policy, ok := p.policyForTier(tierID)
if ok && policy.Cooldown > 0 {
return policy.Cooldown
}
return 5 * time.Minute
}
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
if p == nil {
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func clampGeminiQuotaInt(value int) int {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
return geminiModelFlash
}
return geminiModelPro
}
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
var totals GeminiUsageTotals
for _, stat := range stats {
switch geminiModelClassFromName(stat.Model) {
case geminiModelFlash:
totals.FlashRequests += stat.Requests
totals.FlashTokens += stat.TotalTokens
totals.FlashCost += stat.ActualCost
default:
totals.ProRequests += stat.Requests
totals.ProTokens += stat.TotalTokens
totals.ProCost += stat.ActualCost
}
}
return totals
}
func geminiQuotaLocation() *time.Location {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
return time.FixedZone("PST", -8*3600)
}
return loc
}
func geminiDailyWindowStart(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
}
func geminiDailyResetTime(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
reset := start.Add(24 * time.Hour)
if !reset.After(localNow) {
reset = reset.Add(24 * time.Hour)
}
return reset
}

View File

@@ -50,7 +50,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
// 2) Refresh if needed (pre-expiry skew).
expiresAt := parseExpiresAt(account)
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
@@ -66,7 +66,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
if p.geminiOAuthService == nil {
return "", errors.New("gemini oauth service not configured")
@@ -83,7 +83,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = parseExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil
}
detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" {
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account)
}
}
@@ -154,18 +158,3 @@ func geminiTokenCacheKey(account *Account) string {
}
return "account:" + strconv.FormatInt(account.ID, 10)
}
func parseExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}

View File

@@ -2,7 +2,6 @@ package service
import (
"context"
"strconv"
"time"
)
@@ -22,16 +21,11 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
return time.Until(*expiresAt) < refreshWindow
}
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {

View File

@@ -11,10 +11,11 @@ type Group struct {
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
DefaultValidityDays int
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

View File

@@ -2,8 +2,29 @@ package service
import "net/http"
// HTTPUpstream interface for making HTTP requests to upstream APIs (Claude, OpenAI, etc.)
// This is a generic interface that can be used for any HTTP-based upstream service.
// HTTPUpstream 上游 HTTP 请求接口
// 用于向上游 APIClaude、OpenAI、Gemini 等)发送请求
// 这是一个通用接口,可用于任何基于 HTTP 的上游服务
//
// 设计说明:
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
type HTTPUpstream interface {
Do(req *http.Request, proxyURL string) (*http.Response, error)
// Do 执行 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID用于连接池隔离隔离策略为 account 或 account_proxy 时生效)
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// 注意:
// - 调用方必须关闭 resp.Body否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
}

Some files were not shown because too many files have changed in this diff Show More