Merge branch 'dev-release'
This commit is contained in:
@@ -1 +1 @@
|
||||
0.1.70
|
||||
0.1.70.2
|
||||
|
||||
@@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
|
||||
redeemCache := repository.NewRedeemCache(redisClient)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||
|
||||
@@ -38,31 +38,32 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
Billing BillingConfig `mapstructure:"billing"`
|
||||
Turnstile TurnstileConfig `mapstructure:"turnstile"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Ops OpsConfig `mapstructure:"ops"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Totp TotpConfig `mapstructure:"totp"`
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
|
||||
SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"`
|
||||
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
|
||||
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
|
||||
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
|
||||
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
Update UpdateConfig `mapstructure:"update"`
|
||||
}
|
||||
|
||||
type GeminiConfig struct {
|
||||
@@ -147,6 +148,7 @@ type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"` // debug/release
|
||||
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接
|
||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP)
|
||||
@@ -226,6 +228,9 @@ type GatewayConfig struct {
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
|
||||
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
|
||||
|
||||
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
|
||||
// MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@@ -525,6 +530,13 @@ type APIKeyAuthCacheConfig struct {
|
||||
Singleflight bool `mapstructure:"singleflight"`
|
||||
}
|
||||
|
||||
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
|
||||
type SubscriptionCacheConfig struct {
|
||||
L1Size int `mapstructure:"l1_size"`
|
||||
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
|
||||
JitterPercent int `mapstructure:"jitter_percent"`
|
||||
}
|
||||
|
||||
// DashboardCacheConfig 仪表盘统计缓存配置
|
||||
type DashboardCacheConfig struct {
|
||||
// Enabled: 是否启用仪表盘缓存
|
||||
@@ -630,6 +642,7 @@ func Load() (*Config, error) {
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL)
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
|
||||
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
|
||||
@@ -702,7 +715,8 @@ func setDefaults() {
|
||||
// Server
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.mode", "release")
|
||||
viper.SetDefault("server.frontend_url", "")
|
||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||
viper.SetDefault("server.trusted_proxies", []string{})
|
||||
@@ -737,7 +751,7 @@ func setDefaults() {
|
||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
||||
viper.SetDefault("security.response_headers.enabled", false)
|
||||
viper.SetDefault("security.response_headers.enabled", true)
|
||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||
viper.SetDefault("security.csp.enabled", true)
|
||||
@@ -775,9 +789,9 @@ func setDefaults() {
|
||||
viper.SetDefault("database.user", "postgres")
|
||||
viper.SetDefault("database.password", "postgres")
|
||||
viper.SetDefault("database.dbname", "sub2api")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
viper.SetDefault("database.max_open_conns", 50)
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.sslmode", "prefer")
|
||||
viper.SetDefault("database.max_open_conns", 256)
|
||||
viper.SetDefault("database.max_idle_conns", 128)
|
||||
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
|
||||
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
|
||||
|
||||
@@ -789,8 +803,8 @@ func setDefaults() {
|
||||
viper.SetDefault("redis.dial_timeout_seconds", 5)
|
||||
viper.SetDefault("redis.read_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.write_timeout_seconds", 3)
|
||||
viper.SetDefault("redis.pool_size", 128)
|
||||
viper.SetDefault("redis.min_idle_conns", 10)
|
||||
viper.SetDefault("redis.pool_size", 1024)
|
||||
viper.SetDefault("redis.min_idle_conns", 128)
|
||||
viper.SetDefault("redis.enable_tls", false)
|
||||
|
||||
// Ops (vNext)
|
||||
@@ -849,6 +863,11 @@ func setDefaults() {
|
||||
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
|
||||
viper.SetDefault("api_key_auth_cache.singleflight", true)
|
||||
|
||||
// Subscription auth L1 cache
|
||||
viper.SetDefault("subscription_cache.l1_size", 16384)
|
||||
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
|
||||
viper.SetDefault("subscription_cache.jitter_percent", 10)
|
||||
|
||||
// Dashboard cache
|
||||
viper.SetDefault("dashboard_cache.enabled", true)
|
||||
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
|
||||
@@ -882,13 +901,14 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.failover_on_400", false)
|
||||
viper.SetDefault("gateway.max_account_switches", 10)
|
||||
viper.SetDefault("gateway.max_account_switches_gemini", 3)
|
||||
viper.SetDefault("gateway.force_codex_cli", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
|
||||
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
|
||||
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
|
||||
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
|
||||
viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+)
|
||||
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒)
|
||||
viper.SetDefault("gateway.max_upstream_clients", 5000)
|
||||
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
|
||||
@@ -933,6 +953,22 @@ func setDefaults() {
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if strings.TrimSpace(c.Server.FrontendURL) != "" {
|
||||
if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: %w", err)
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL))
|
||||
if err != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: %w", err)
|
||||
}
|
||||
if u.RawQuery != "" || u.ForceQuery {
|
||||
return fmt.Errorf("server.frontend_url invalid: must not include query")
|
||||
}
|
||||
if u.User != nil {
|
||||
return fmt.Errorf("server.frontend_url invalid: must not include userinfo")
|
||||
}
|
||||
warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL)
|
||||
}
|
||||
if c.JWT.ExpireHour <= 0 {
|
||||
return fmt.Errorf("jwt.expire_hour must be positive")
|
||||
}
|
||||
|
||||
@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
||||
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
||||
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
||||
}
|
||||
if cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
||||
if !cfg.Security.ResponseHeaders.Enabled {
|
||||
t.Fatalf("ResponseHeaders.Enabled = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultServerMode(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Mode != "release" {
|
||||
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Database.SSLMode != "prefer" {
|
||||
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -424,6 +450,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateServerFrontendURL(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() frontend_url valid error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com/path"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate() frontend_url with path valid error: %v", err)
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://example.com?utm=1"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject server.frontend_url with query")
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "https://user:pass@example.com"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject server.frontend_url with userinfo")
|
||||
}
|
||||
|
||||
cfg.Server.FrontendURL = "/relative"
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatalf("Validate() should reject relative server.frontend_url")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFrontendRedirectURL(t *testing.T) {
|
||||
if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
|
||||
t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
|
||||
|
||||
@@ -3,6 +3,7 @@ package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
// 阶段一:预验证所有账号存在,收集 credentials
|
||||
type accountUpdate struct {
|
||||
ID int64
|
||||
Credentials map[string]any
|
||||
}
|
||||
updates := make([]accountUpdate, 0, len(req.AccountIDs))
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
|
||||
return
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
|
||||
}
|
||||
|
||||
// Update account
|
||||
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
|
||||
for _, u := range updates {
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
Credentials: u.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
|
||||
response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
|
||||
return
|
||||
}
|
||||
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
"success": len(updates),
|
||||
"failed": 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
|
||||
type failingAdminService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
updateCallCount atomic.Int64
|
||||
}
|
||||
|
||||
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
f.updateCallCount.Add(1)
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("database error")
|
||||
}
|
||||
return f.stubAdminService.UpdateAccount(ctx, id, input)
|
||||
}
|
||||
|
||||
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
|
||||
return router, handler
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test-uuid",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
|
||||
// 让第 2 个账号(ID=2)更新时失败
|
||||
svc := &failingAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 2,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "org_uuid",
|
||||
Value: "test-org",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
|
||||
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
|
||||
require.Equal(t, int64(2), svc.updateCallCount.Load(),
|
||||
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
|
||||
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
|
||||
svc := &getAccountFailingService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 1,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
|
||||
}
|
||||
|
||||
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
|
||||
type getAccountFailingService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
}
|
||||
|
||||
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return f.stubAdminService.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": "not-a-bool",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"intercept_warmup_requests 传入非 bool 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": true,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"intercept_warmup_requests 传入合法 bool 值应返回 200")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入非 string 类型(number),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": 12345,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"account_uuid 传入非 string 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入 null(设置为空),应正常通过
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": nil,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"account_uuid 传入 null 应返回 200")
|
||||
}
|
||||
@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
|
||||
func truncateSearchByRune(search string, maxRunes int) string {
|
||||
if runes := []rune(search); len(runes) > maxRunes {
|
||||
return string(runes[:maxRunes])
|
||||
}
|
||||
return search
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxRunes int
|
||||
wantLen int // 期望的 rune 长度
|
||||
}{
|
||||
{
|
||||
name: "纯中文超长",
|
||||
input: string(make([]rune, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "纯 ASCII 超长",
|
||||
input: string(make([]byte, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "恰好 100 个字符",
|
||||
input: string(make([]rune, 100)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "不足 100 字符不截断",
|
||||
input: "hello世界",
|
||||
maxRunes: 100,
|
||||
wantLen: 7,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := truncateSearchByRune(tc.input, tc.maxRunes)
|
||||
require.Equal(t, tc.wantLen, len([]rune(result)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
|
||||
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
|
||||
input := ""
|
||||
for i := 0; i < 101; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
require.Equal(t, 100, len([]rune(result)))
|
||||
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
|
||||
require.Equal(t, 300, len(result))
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
|
||||
// 50 个 ASCII + 51 个中文 = 101 个 rune
|
||||
input := ""
|
||||
for i := 0; i < 50; i++ {
|
||||
input += "a"
|
||||
}
|
||||
for i := 0; i < 51; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
runes := []rune(result)
|
||||
require.Equal(t, 100, len(runes))
|
||||
// 前 50 个应该是 'a',后 50 个应该是 '中'
|
||||
require.Equal(t, 'a', runes[0])
|
||||
require.Equal(t, 'a', runes[49])
|
||||
require.Equal(t, '中', runes[50])
|
||||
require.Equal(t, '中', runes[99])
|
||||
}
|
||||
@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
if runes := []rune(search); len(runes) > 100 {
|
||||
search = string(runes[:100])
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -448,17 +449,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
|
||||
@@ -236,7 +236,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -284,12 +285,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -301,14 +302,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
@@ -398,7 +397,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -446,11 +446,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -462,13 +463,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
@@ -967,7 +967,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
@@ -1238,7 +1239,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = err.Error()
|
||||
log.Printf("[Gateway] billing error details: %v", err)
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type OpenAIGatewayHandler struct {
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -54,6 +55,7 @@ func NewOpenAIGatewayHandler(
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +111,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if !openai.IsCodexCLIRequest(userAgent) {
|
||||
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
|
||||
if !isCodexCLI {
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) == "" {
|
||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||
@@ -218,7 +221,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -251,11 +255,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -267,13 +272,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
// generateRandomID 生成密码学安全的随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
randBytes := make([]byte, 12)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
panic("crypto/rand unavailable: " + err.Error())
|
||||
}
|
||||
for i, b := range randBytes {
|
||||
result[i] = chars[int(b)%len(chars)]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
require.Len(t, id, 12, "ID 长度应为 12")
|
||||
_, dup := seen[id]
|
||||
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
validSet := make(map[byte]struct{}, len(validChars))
|
||||
for i := 0; i < len(validChars); i++ {
|
||||
validSet[validChars[i]] = struct{}{}
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := generateRandomID()
|
||||
for j := 0; j < len(id); j++ {
|
||||
_, ok := validSet[id[j]]
|
||||
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -54,29 +54,34 @@ func normalizeIP(ip string) string {
|
||||
return ip
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// 私有 IP 范围
|
||||
privateBlocks := []string{
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
|
||||
for _, block := range privateBlocks {
|
||||
_, cidr, err := net.ParseCIDR(block)
|
||||
} {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
panic("invalid CIDR: " + cidr)
|
||||
}
|
||||
if cidr.Contains(ip) {
|
||||
privateNets = append(privateNets, block)
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, block := range privateNets {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
51
backend/internal/pkg/ip/ip_test.go
Normal file
51
backend/internal/pkg/ip/ip_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
//go:build unit
|
||||
|
||||
package ip
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// 私有 IPv4
|
||||
{"10.x 私有地址", "10.0.0.1", true},
|
||||
{"10.x 私有地址段末", "10.255.255.255", true},
|
||||
{"172.16.x 私有地址", "172.16.0.1", true},
|
||||
{"172.31.x 私有地址", "172.31.255.255", true},
|
||||
{"192.168.x 私有地址", "192.168.1.1", true},
|
||||
{"127.0.0.1 本地回环", "127.0.0.1", true},
|
||||
{"127.x 回环段", "127.255.255.255", true},
|
||||
|
||||
// 公网 IPv4
|
||||
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
|
||||
{"1.1.1.1 公网", "1.1.1.1", false},
|
||||
{"172.15.255.255 非私有", "172.15.255.255", false},
|
||||
{"172.32.0.0 非私有", "172.32.0.0", false},
|
||||
{"11.0.0.1 公网", "11.0.0.1", false},
|
||||
|
||||
// IPv6
|
||||
{"::1 IPv6 回环", "::1", true},
|
||||
{"fc00:: IPv6 私有", "fc00::1", true},
|
||||
{"fd00:: IPv6 私有", "fd00::1", true},
|
||||
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
|
||||
|
||||
// 无效输入
|
||||
{"空字符串", "", false},
|
||||
{"非法字符串", "not-an-ip", false},
|
||||
{"不完整 IP", "192.168", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := isPrivateIP(tc.ip)
|
||||
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
|
||||
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -47,6 +47,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -92,7 +93,9 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
|
||||
43
backend/internal/pkg/openai/oauth_test.go
Normal file
43
backend/internal/pkg/openai/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
|
||||
@@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
// Use raw SQL for atomic increment to avoid race conditions
|
||||
// First get current value
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldQuotaUsed).
|
||||
Only(ctx)
|
||||
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||
Where(apikey.DeletedAtIsNil()).
|
||||
AddQuotaUsed(amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newValue := m.QuotaUsed + amount
|
||||
|
||||
// Update with new value
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetQuotaUsed(newValue).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected == 0 {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
return newValue, nil
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
|
||||
@@ -4,11 +4,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
// --- IncrementQuotaUsed ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||
user := s.mustCreateUser("incr-basic@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||
|
||||
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||
|
||||
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
user := s.mustCreateUser("incr-deleted@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
u, err := client.User.Create().
|
||||
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||
SetPasswordHash("hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||
Name: "Concurrent",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||
t.Cleanup(func() {
|
||||
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||
})
|
||||
|
||||
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||
const goroutines = 10
|
||||
const increment = 1.0
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, e := range errs {
|
||||
require.NoError(t, e, "goroutine %d failed", i)
|
||||
}
|
||||
|
||||
// 验证最终结果
|
||||
got, err := repo.GetByID(ctx, k.ID)
|
||||
require.NoError(t, err, "GetByID")
|
||||
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,15 @@ const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
func jitteredTTL() time.Duration {
|
||||
jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter
|
||||
return billingCacheTTL + jitter
|
||||
}
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
@@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
pipe.Expire(ctx, key, jitteredTTL())
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,22 @@ import (
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
"day": "YYYY-MM-DD",
|
||||
"week": "IYYY-IW",
|
||||
"month": "YYYY-MM",
|
||||
}
|
||||
|
||||
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
|
||||
func safeDateFormat(granularity string) string {
|
||||
if f, ok := dateFormatWhitelist[granularity]; ok {
|
||||
return f
|
||||
}
|
||||
return "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_keys AS (
|
||||
@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
|
||||
|
||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_users AS (
|
||||
@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||
result := make(map[int64]*BatchUserUsageStats)
|
||||
if len(userIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range userIDs {
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
query := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
query := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1)
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
|
||||
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
s.Require().NotNil(stats[user1.ID])
|
||||
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeDateFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
granularity string
|
||||
expected string
|
||||
}{
|
||||
// 合法值
|
||||
{"hour", "hour", "YYYY-MM-DD HH24:00"},
|
||||
{"day", "day", "YYYY-MM-DD"},
|
||||
{"week", "week", "IYYY-IW"},
|
||||
{"month", "month", "YYYY-MM"},
|
||||
|
||||
// 非法值回退到默认
|
||||
{"空字符串", "", "YYYY-MM-DD"},
|
||||
{"未知粒度 year", "year", "YYYY-MM-DD"},
|
||||
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
|
||||
|
||||
// 恶意字符串
|
||||
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
|
||||
{"带引号", "day'", "YYYY-MM-DD"},
|
||||
{"带括号", "day)", "YYYY-MM-DD"},
|
||||
{"Unicode", "日", "YYYY-MM-DD"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := safeDateFormat(tc.granularity)
|
||||
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -592,13 +592,13 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
RunMode: config.RunModeStandard,
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo, nil)
|
||||
userService := service.NewUserService(userRepo, nil, nil)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, cfg)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
@@ -1602,11 +1602,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -176,6 +176,12 @@ func validateJWTForAdmin(
|
||||
return false
|
||||
}
|
||||
|
||||
// 校验 TokenVersion,确保管理员改密后旧 token 失效
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
|
||||
194
backend/internal/server/middleware/admin_auth_test.go
Normal file
194
backend/internal/server/middleware/admin_auth_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
Email: "admin@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 2,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
userRepo := &stubUserRepo{
|
||||
getByID: func(ctx context.Context, id int64) (*service.User, error) {
|
||||
if id != admin.ID {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
clone := *admin
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
userService := service.NewUserService(userRepo, nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
type stubUserRepo struct {
|
||||
getByID func(ctx context.Context, id int64) (*service.User, error)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
if s.getByID == nil {
|
||||
panic("GetByID not stubbed")
|
||||
}
|
||||
return s.getByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
panic("unexpected GetByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
panic("unexpected ExistsByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -134,7 +133,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:验证订阅
|
||||
// 订阅模式:获取订阅(L1 缓存 + singleflight)
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
@@ -145,30 +144,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 激活滑动窗口(首次使用时)
|
||||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 检查并重置过期窗口
|
||||
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
// 合并验证 + 限额检查(纯内存操作)
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -99,7 +99,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
|
||||
@@ -36,8 +36,8 @@ type UsageLogRepository interface {
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
|
||||
@@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// ForwardUpstream 透传请求到上游 Antigravity 服务
|
||||
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
|
||||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
// 获取上游配置
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
// 解析请求获取模型信息
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||
|
||||
// 透传 Claude 相关 headers
|
||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||
req.Header.Set("anthropic-version", v)
|
||||
}
|
||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||
req.Header.Set("anthropic-beta", v)
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 429 错误时标记账号限流
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
|
||||
}
|
||||
|
||||
// 透传上游错误
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 处理成功响应(流式/非流式)
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
if claudeReq.Stream {
|
||||
// 流式响应:透传
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
|
||||
} else {
|
||||
// 非流式响应:直接透传
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
|
||||
// 提取 usage
|
||||
usage = s.extractClaudeUsage(respBody)
|
||||
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
}
|
||||
|
||||
// 构建计费结果
|
||||
duration := time.Since(startTime)
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamUpstreamResponse 透传上游流式响应并提取 usage
|
||||
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
var firstTokenRecorded bool
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
// 记录首 token 时间
|
||||
if !firstTokenRecorded && len(line) > 0 {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
firstTokenRecorded = true
|
||||
}
|
||||
|
||||
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
dataStr := bytes.TrimPrefix(line, []byte("data: "))
|
||||
var event map[string]any
|
||||
if json.Unmarshal(dataStr, &event) == nil {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 透传行
|
||||
_, _ = c.Writer.Write(line)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
return usage, firstTokenMs
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||
usage := &ClaudeUsage{}
|
||||
var resp map[string]any
|
||||
if json.Unmarshal(body, &resp) != nil {
|
||||
return usage
|
||||
}
|
||||
if u, ok := resp["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(time.Now()),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
default:
|
||||
@@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
@@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
@@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
var firstTokenMs *int
|
||||
var last map[string]any
|
||||
@@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||||
@@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n"))
|
||||
}()
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
start := time.Now().Add(-10 * time.Millisecond)
|
||||
usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 1, usage.InputTokens)
|
||||
// 第二次事件覆盖 output_tokens
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 3, usage.CacheReadInputTokens)
|
||||
require.Equal(t, 4, usage.CacheCreationInputTokens)
|
||||
|
||||
if firstTokenMs == nil {
|
||||
t.Fatalf("expected firstTokenMs to be set")
|
||||
}
|
||||
// 确保有透传输出
|
||||
require.True(t, strings.Contains(writer.Body.String(), "data:"))
|
||||
}
|
||||
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
|
||||
singleflight bool
|
||||
}
|
||||
|
||||
var (
|
||||
jitterRandMu sync.Mutex
|
||||
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
)
|
||||
|
||||
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||
if cfg == nil {
|
||||
return apiKeyAuthCacheConfig{}
|
||||
@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||
return c.negativeTTL > 0
|
||||
}
|
||||
|
||||
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
|
||||
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
|
||||
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return ttl
|
||||
@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
percent = 100
|
||||
}
|
||||
delta := float64(percent) / 100
|
||||
jitterRandMu.Lock()
|
||||
randVal := jitterRand.Float64()
|
||||
jitterRandMu.Unlock()
|
||||
randVal := rand.Float64()
|
||||
factor := 1 - delta + randVal*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
|
||||
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
|
||||
// replaceModelInResponseBody 替换响应体中的model字段
|
||||
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
|
||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
|
||||
52
backend/internal/service/gateway_service_streaming_test.go
Normal file
52
backend/internal/service/gateway_service_streaming_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// Minimal SSE event to trigger parseSSEUsage
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", nil, false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 3, result.usage.InputTokens)
|
||||
require.Equal(t, 7, result.usage.OutputTokens)
|
||||
}
|
||||
@@ -2,19 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||
codexCacheTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
//go:embed prompts/codex_cli_instructions.md
|
||||
@@ -77,12 +65,6 @@ type codexTransformResult struct {
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
type opencodeCacheMetadata struct {
|
||||
ETag string `json:"etag"`
|
||||
LastFetch string `json:"lastFetch,omitempty"`
|
||||
LastChecked int64 `json:"lastChecked"`
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||
cacheDir := codexCachePath("")
|
||||
if cacheDir == "" {
|
||||
return ""
|
||||
}
|
||||
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||
|
||||
var cachedContent string
|
||||
if content, ok := readFile(cacheFile); ok {
|
||||
cachedContent = content
|
||||
}
|
||||
|
||||
var meta opencodeCacheMetadata
|
||||
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||
return cachedContent
|
||||
}
|
||||
}
|
||||
|
||||
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||
return cachedContent
|
||||
}
|
||||
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||
_ = writeFile(cacheFile, content)
|
||||
meta = opencodeCacheMetadata{
|
||||
ETag: etag,
|
||||
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||
LastChecked: time.Now().UnixMilli(),
|
||||
}
|
||||
_ = writeJSON(metaFile, meta)
|
||||
return content
|
||||
}
|
||||
|
||||
return cachedContent
|
||||
}
|
||||
|
||||
func getOpenCodeCodexHeader() string {
|
||||
// 优先从 opencode 仓库缓存获取指令。
|
||||
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||
|
||||
// 若 opencode 指令可用,直接返回。
|
||||
if opencodeInstructions != "" {
|
||||
return opencodeInstructions
|
||||
}
|
||||
|
||||
// 否则回退使用本地 Codex CLI 指令。
|
||||
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
|
||||
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
|
||||
}
|
||||
|
||||
// applyInstructions 处理 instructions 字段
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
|
||||
// isCodexCLI=false: 优先使用 opencode 指令覆盖
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
|
||||
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
|
||||
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
if isCodexCLI {
|
||||
return applyCodexCLIInstructions(reqBody)
|
||||
@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
}
|
||||
|
||||
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
|
||||
// 仅在 instructions 为空时添加 opencode 指令
|
||||
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
|
||||
func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
if !isInstructionsEmpty(reqBody) {
|
||||
return false // 已有有效 instructions,不修改
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
instructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
|
||||
// 优先使用 opencode 指令覆盖
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
|
||||
// 优先使用内置 Codex CLI 指令覆盖
|
||||
func applyOpenCodeInstructions(reqBody map[string]any) bool {
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func codexCachePath(filename string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||
if filename == "" {
|
||||
return cacheDir
|
||||
}
|
||||
return filepath.Join(cacheDir, filename)
|
||||
}
|
||||
|
||||
func readFile(path string) (string, bool) {
|
||||
if path == "" {
|
||||
return "", false
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(data), true
|
||||
}
|
||||
|
||||
func writeFile(path, content string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty cache path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
func loadJSON(path string, target any) bool {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(data, target); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func writeJSON(path string, value any) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty json path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "sub2api-codex")
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", "", resp.StatusCode, err
|
||||
}
|
||||
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
// 显式 store=true 也会强制为 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"tools": []any{
|
||||
@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||
// Codex CLI 场景:已有 instructions 时保持不变
|
||||
setupCodexCache(t)
|
||||
// Codex CLI 场景:已有 instructions 时不修改
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"instructions": "user custom instructions",
|
||||
"input": []any{},
|
||||
"instructions": "existing instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true)
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "user custom instructions", instructions)
|
||||
// instructions 未变,但其他字段(如 store、stream)可能被修改
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
|
||||
// Codex CLI 场景:无 instructions 时补充内置指令
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true)
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, instructions)
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header)
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// 使用临时 HOME 避免触发网络拉取 header。
|
||||
// Windows 使用 USERPROFILE,Unix 使用 HOME。
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("HOME", tempDir)
|
||||
t.Setenv("USERPROFILE", tempDir)
|
||||
|
||||
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||
|
||||
meta := map[string]any{
|
||||
"etag": "",
|
||||
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||
"lastChecked": time.Now().UnixMilli(),
|
||||
}
|
||||
data, err := json.Marshal(meta)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||
require.Equal(t, "existing instructions", instructions)
|
||||
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
|
||||
_ = result
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
|
||||
// Codex CLI 场景:无 instructions 时补充默认值
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用 opencode 指令覆盖
|
||||
setupCodexCache(t)
|
||||
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
|
||||
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
// 检查嵌套的 response.model 字段
|
||||
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "response.model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
return line
|
||||
@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubOpenAIAccountRepo struct {
|
||||
@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 1, result.usage.InputTokens)
|
||||
require.Equal(t, 2, result.usage.OutputTokens)
|
||||
require.Equal(t, 3, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
t.Fatalf("expected non-allowlisted host to fail")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||
|
||||
func TestReplaceModelInSSELine(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "顶层 model 字段替换",
|
||||
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-custom-model",
|
||||
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "嵌套 response.model 替换",
|
||||
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
|
||||
},
|
||||
{
|
||||
name: "model 不匹配时不替换",
|
||||
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "无 model 字段时不替换",
|
||||
line: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "空 data 行",
|
||||
line: `data: `,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: `,
|
||||
},
|
||||
{
|
||||
name: "[DONE] 行",
|
||||
line: `data: [DONE]`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: [DONE]`,
|
||||
},
|
||||
{
|
||||
name: "非 data: 前缀行",
|
||||
line: `event: message`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `event: message`,
|
||||
},
|
||||
{
|
||||
name: "非法 JSON 不替换",
|
||||
line: `data: {invalid json}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {invalid json}`,
|
||||
},
|
||||
{
|
||||
name: "无空格 data: 格式",
|
||||
line: `data:{"id":"x","model":"gpt-4o"}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"x","model":"my-model"}`,
|
||||
},
|
||||
{
|
||||
name: "model 名含特殊字符",
|
||||
line: `data: {"model":"org/model-v2.1-beta"}`,
|
||||
from: "org/model-v2.1-beta",
|
||||
to: "custom/alias",
|
||||
expected: `data: {"model":"custom/alias"}`,
|
||||
},
|
||||
{
|
||||
name: "空行",
|
||||
line: "",
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "保持其他字段不变",
|
||||
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||
},
|
||||
{
|
||||
name: "顶层优先于嵌套:同时存在两个 model",
|
||||
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
|
||||
from: "gpt-4o",
|
||||
to: "replaced",
|
||||
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceModelInSSEBody(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "多行 SSE body 替换",
|
||||
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||
},
|
||||
{
|
||||
name: "无需替换的 body",
|
||||
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||
},
|
||||
{
|
||||
name: "混合 event 和 data 行",
|
||||
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
|
||||
},
|
||||
{
|
||||
name: "空 body",
|
||||
body: "",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceModelInResponseBody(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "替换顶层 model",
|
||||
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "model 不匹配不替换",
|
||||
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "无 model 字段不替换",
|
||||
body: `{"id":"chatcmpl-123","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "非法 JSON 返回原值",
|
||||
body: `not json`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `not json`,
|
||||
},
|
||||
{
|
||||
name: "空 body 返回原值",
|
||||
body: ``,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: ``,
|
||||
},
|
||||
{
|
||||
name: "保持嵌套结构不变",
|
||||
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
24
backend/internal/service/sse_scanner_buffer_pool.go
Normal file
24
backend/internal/service/sse_scanner_buffer_pool.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package service
|
||||
|
||||
import "sync"
|
||||
|
||||
const sseScannerBuf64KSize = 64 * 1024
|
||||
|
||||
type sseScannerBuf64K [sseScannerBuf64KSize]byte
|
||||
|
||||
var sseScannerBuf64KPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(sseScannerBuf64K)
|
||||
},
|
||||
}
|
||||
|
||||
func getSSEScannerBuf64K() *sseScannerBuf64K {
|
||||
return sseScannerBuf64KPool.Get().(*sseScannerBuf64K)
|
||||
}
|
||||
|
||||
func putSSEScannerBuf64K(buf *sseScannerBuf64K) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
sseScannerBuf64KPool.Put(buf)
|
||||
}
|
||||
19
backend/internal/service/sse_scanner_buffer_pool_test.go
Normal file
19
backend/internal/service/sse_scanner_buffer_pool_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
|
||||
buf := getSSEScannerBuf64K()
|
||||
require.NotNil(t, buf)
|
||||
require.Equal(t, sseScannerBuf64KSize, len(buf[:]))
|
||||
|
||||
buf[0] = 1
|
||||
putSSEScannerBuf64K(buf)
|
||||
|
||||
// 允许传入 nil,确保不会 panic
|
||||
putSSEScannerBuf64K(nil)
|
||||
}
|
||||
@@ -4,10 +4,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
|
||||
@@ -35,15 +40,76 @@ type SubscriptionService struct {
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
|
||||
// L1 缓存:加速中间件热路径的订阅查询
|
||||
subCacheL1 *ristretto.Cache
|
||||
subCacheGroup singleflight.Group
|
||||
subCacheTTL time.Duration
|
||||
subCacheJitter int // 抖动百分比
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
|
||||
svc := &SubscriptionService{
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
svc.initSubCache(cfg)
|
||||
return svc
|
||||
}
|
||||
|
||||
// initSubCache 初始化订阅 L1 缓存
|
||||
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
sc := cfg.SubscriptionCache
|
||||
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: int64(sc.L1Size) * 10,
|
||||
MaxCost: int64(sc.L1Size),
|
||||
BufferItems: 64,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
|
||||
return
|
||||
}
|
||||
s.subCacheL1 = cache
|
||||
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
|
||||
s.subCacheJitter = sc.JitterPercent
|
||||
}
|
||||
|
||||
// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销)
|
||||
func subCacheKey(userID, groupID int64) string {
|
||||
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
|
||||
}
|
||||
|
||||
// jitteredTTL 为 TTL 添加抖动,避免集中过期
|
||||
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 || s.subCacheJitter <= 0 {
|
||||
return ttl
|
||||
}
|
||||
pct := s.subCacheJitter
|
||||
if pct > 100 {
|
||||
pct = 100
|
||||
}
|
||||
delta := float64(pct) / 100
|
||||
factor := 1 - delta + rand.Float64()*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
}
|
||||
return time.Duration(float64(ttl) * factor)
|
||||
}
|
||||
|
||||
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
|
||||
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
|
||||
if s.subCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
s.subCacheL1.Del(subCacheKey(userID, groupID))
|
||||
}
|
||||
|
||||
// AssignSubscriptionInput 分配订阅输入
|
||||
@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
// 使用 L1 缓存 + singleflight 加速中间件热路径。
|
||||
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
key := subCacheKey(userID, groupID)
|
||||
|
||||
// L1 缓存命中:返回浅拷贝
|
||||
if s.subCacheL1 != nil {
|
||||
if v, ok := s.subCacheL1.Get(key); ok {
|
||||
if sub, ok := v.(*UserSubscription); ok {
|
||||
cp := *sub
|
||||
return &cp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return sub, nil
|
||||
|
||||
// singleflight 防止并发击穿
|
||||
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
// 写入 L1 缓存
|
||||
if s.subCacheL1 != nil {
|
||||
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
|
||||
}
|
||||
return sub, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// singleflight 返回的也是缓存指针,需要浅拷贝
|
||||
cp := *value.(*UserSubscription)
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
@@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
|
||||
if needsInvalidateCache && s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
// 如果有窗口被重置,失效缓存以保持一致性
|
||||
if needsInvalidateCache {
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
|
||||
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
|
||||
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
|
||||
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
|
||||
// 1. 验证订阅状态
|
||||
if sub.Status == SubscriptionStatusExpired {
|
||||
return false, ErrSubscriptionExpired
|
||||
}
|
||||
if sub.Status == SubscriptionStatusSuspended {
|
||||
return false, ErrSubscriptionSuspended
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
return false, ErrSubscriptionExpired
|
||||
}
|
||||
|
||||
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
|
||||
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
|
||||
if sub.NeedsDailyReset() {
|
||||
sub.DailyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if sub.NeedsWeeklyReset() {
|
||||
sub.WeeklyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if sub.NeedsMonthlyReset() {
|
||||
sub.MonthlyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if !sub.IsWindowActivated() {
|
||||
needsMaintenance = true
|
||||
}
|
||||
|
||||
// 3. 检查用量限额
|
||||
if !sub.CheckDailyLimit(group, 0) {
|
||||
return needsMaintenance, ErrDailyLimitExceeded
|
||||
}
|
||||
if !sub.CheckWeeklyLimit(group, 0) {
|
||||
return needsMaintenance, ErrWeeklyLimitExceeded
|
||||
}
|
||||
if !sub.CheckMonthlyLimit(group, 0) {
|
||||
return needsMaintenance, ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return needsMaintenance, nil
|
||||
}
|
||||
|
||||
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
|
||||
// 使用独立 context,不受请求取消影响。
|
||||
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
|
||||
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
|
||||
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
|
||||
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 激活窗口(首次使用时)
|
||||
if !sub.IsWindowActivated() {
|
||||
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重置过期窗口
|
||||
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 失效 L1 缓存,确保后续请求拿到更新后的数据
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
|
||||
@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
|
||||
}
|
||||
|
||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
|
||||
type UserService struct {
|
||||
userRepo UserRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
billingCache BillingCache
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
|
||||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
billingCache: billingCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCache != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
186
backend/internal/service/user_service_test.go
Normal file
186
backend/internal/service/user_service_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock: UserRepository ---
|
||||
|
||||
type mockUserRepo struct {
|
||||
updateBalanceErr error
|
||||
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
if m.updateBalanceFn != nil {
|
||||
return m.updateBalanceFn(ctx, id, amount)
|
||||
}
|
||||
return m.updateBalanceErr
|
||||
}
|
||||
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
// --- mock: APIKeyAuthCacheInvalidator ---
|
||||
|
||||
type mockAuthCacheInvalidator struct {
|
||||
invalidatedUserIDs []int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||
}
|
||||
|
||||
// --- mock: BillingCache ---
|
||||
|
||||
type mockBillingCache struct {
|
||||
invalidateErr error
|
||||
invalidateCallCount atomic.Int64
|
||||
invalidatedUserIDs []int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
|
||||
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
|
||||
m.invalidateCallCount.Add(1)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||
return m.invalidateErr
|
||||
}
|
||||
func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 测试 ---
|
||||
|
||||
func TestUpdateBalance_Success(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 42, 100.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 等待异步 goroutine 完成
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
|
||||
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
svc := NewUserService(repo, nil, nil) // billingCache = nil
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 1, 50.0)
|
||||
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 99, 200.0)
|
||||
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
|
||||
|
||||
// 等待异步 goroutine 完成(即使失败也应调用)
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
|
||||
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 1, 100.0)
|
||||
require.Error(t, err, "repo 失败时应返回错误")
|
||||
require.Contains(t, err.Error(), "update balance")
|
||||
|
||||
// repo 失败时不应触发缓存失效
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
|
||||
"repo 失败时不应调用 InvalidateUserBalance")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
auth := &mockAuthCacheInvalidator{}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, auth, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 77, 300.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 auth cache 同步失效
|
||||
auth.mu.Lock()
|
||||
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
|
||||
auth.mu.Unlock()
|
||||
|
||||
// 验证 billing cache 异步失效
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestNewUserService_FieldsAssignment(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
auth := &mockAuthCacheInvalidator{}
|
||||
cache := &mockBillingCache{}
|
||||
|
||||
svc := NewUserService(repo, auth, cache)
|
||||
require.NotNil(t, svc)
|
||||
require.Equal(t, repo, svc.userRepo)
|
||||
require.Equal(t, auth, svc.authCacheInvalidator)
|
||||
require.Equal(t, cache, svc.billingCache)
|
||||
}
|
||||
@@ -58,13 +58,67 @@ TZ=Asia/Shanghai
|
||||
POSTGRES_USER=sub2api
|
||||
POSTGRES_PASSWORD=change_this_secure_password
|
||||
POSTGRES_DB=sub2api
|
||||
# PostgreSQL 监听端口(同时用于 PG 服务端和应用连接,默认 5432)
|
||||
DATABASE_PORT=5432
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# PostgreSQL 服务端参数(可选;主要用于 deploy/docker-compose-aicodex.yml)
|
||||
# -----------------------------------------------------------------------------
|
||||
# POSTGRES_MAX_CONNECTIONS:PostgreSQL 服务端允许的最大连接数。
|
||||
# 必须 >=(所有 Sub2API 实例的 DATABASE_MAX_OPEN_CONNS 之和)+ 预留余量(例如 20%)。
|
||||
POSTGRES_MAX_CONNECTIONS=1024
|
||||
# POSTGRES_SHARED_BUFFERS:PostgreSQL 用于缓存数据页的共享内存。
|
||||
# 常见建议:物理内存的 10%~25%(容器内存受限时请按实际限制调整)。
|
||||
# 8GB 内存容器参考:1GB。
|
||||
POSTGRES_SHARED_BUFFERS=1GB
|
||||
# POSTGRES_EFFECTIVE_CACHE_SIZE:查询规划器“假设可用的 OS 缓存大小”(不等于实际分配)。
|
||||
# 常见建议:物理内存的 50%~75%。
|
||||
# 8GB 内存容器参考:6GB。
|
||||
POSTGRES_EFFECTIVE_CACHE_SIZE=4GB
|
||||
# POSTGRES_MAINTENANCE_WORK_MEM:维护操作内存(VACUUM/CREATE INDEX 等)。
|
||||
# 值越大维护越快,但会占用更多内存。
|
||||
# 8GB 内存容器参考:128MB。
|
||||
POSTGRES_MAINTENANCE_WORK_MEM=128MB
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# PostgreSQL 连接池参数(可选,默认与程序内置一致)
|
||||
# -----------------------------------------------------------------------------
|
||||
# 说明:
|
||||
# - 这些参数控制 Sub2API 进程到 PostgreSQL 的连接池大小(不是 PostgreSQL 自身的 max_connections)。
|
||||
# - 多实例/多副本部署时,总连接上限约等于:实例数 * DATABASE_MAX_OPEN_CONNS。
|
||||
# - 连接池过大可能导致:数据库连接耗尽、内存占用上升、上下文切换增多,反而变慢。
|
||||
# - 建议结合 PostgreSQL 的 max_connections 与机器规格逐步调优:
|
||||
# 通常把应用总连接上限控制在 max_connections 的 50%~80% 更稳妥。
|
||||
#
|
||||
# DATABASE_MAX_OPEN_CONNS:最大打开连接数(活跃+空闲),达到后新请求会等待可用连接。
|
||||
# 典型范围:50~500(取决于 DB 规格、实例数、SQL 复杂度)。
|
||||
DATABASE_MAX_OPEN_CONNS=256
|
||||
# DATABASE_MAX_IDLE_CONNS:最大空闲连接数(热连接),建议 <= MAX_OPEN。
|
||||
# 太小会频繁建连增加延迟;太大会长期占用数据库资源。
|
||||
DATABASE_MAX_IDLE_CONNS=128
|
||||
# DATABASE_CONN_MAX_LIFETIME_MINUTES:单个连接最大存活时间(单位:分钟)。
|
||||
# 用于避免连接长期不重建导致的中间件/LB/NAT 异常或服务端重启后的“僵尸连接”。
|
||||
# 设置为 0 表示不限制(一般不建议生产环境)。
|
||||
DATABASE_CONN_MAX_LIFETIME_MINUTES=30
|
||||
# DATABASE_CONN_MAX_IDLE_TIME_MINUTES:空闲连接最大存活时间(单位:分钟)。
|
||||
# 超过该时间的空闲连接会被回收,防止长时间闲置占用连接数。
|
||||
# 设置为 0 表示不限制(一般不建议生产环境)。
|
||||
DATABASE_CONN_MAX_IDLE_TIME_MINUTES=5
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis 监听端口(同时用于应用连接和 Redis 服务端,默认 6379)
|
||||
REDIS_PORT=6379
|
||||
# Leave empty for no password (default for local development)
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
# Redis 服务端最大客户端连接数(可选;主要用于 deploy/docker-compose-aicodex.yml)
|
||||
REDIS_MAXCLIENTS=50000
|
||||
# Redis 连接池大小(默认 1024)
|
||||
REDIS_POOL_SIZE=4096
|
||||
# Redis 最小空闲连接数(默认 10)
|
||||
REDIS_MIN_IDLE_CONNS=256
|
||||
REDIS_ENABLE_TLS=false
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -119,6 +173,19 @@ RATE_LIMIT_OVERLOAD_COOLDOWN_MINUTES=10
|
||||
# Gateway Scheduling (Optional)
|
||||
# 调度缓存与受控回源配置(缓存就绪且命中时不读 DB)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
|
||||
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
|
||||
#
|
||||
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
|
||||
#
|
||||
# 默认:false
|
||||
GATEWAY_FORCE_CODEX_CLI=false
|
||||
# 上游连接池:每主机最大连接数(默认 1024;流式/HTTP1.1 场景可调大,如 2400/4096)
|
||||
GATEWAY_MAX_CONNS_PER_HOST=2048
|
||||
# 上游连接池:最大空闲连接总数(默认 2560;账号/代理隔离 + 高并发场景可调大)
|
||||
GATEWAY_MAX_IDLE_CONNS=8192
|
||||
# 上游连接池:每主机最大空闲连接(默认 120)
|
||||
GATEWAY_MAX_IDLE_CONNS_PER_HOST=4096
|
||||
# 粘性会话最大排队长度
|
||||
GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING=3
|
||||
# 粘性会话等待超时(时间段,例如 45s)
|
||||
|
||||
@@ -20,6 +20,10 @@ server:
|
||||
# Mode: "debug" for development, "release" for production
|
||||
# 运行模式:"debug" 用于开发,"release" 用于生产环境
|
||||
mode: "release"
|
||||
# Frontend base URL used to generate external links in emails (e.g. password reset)
|
||||
# 用于生成邮件中的外部链接(例如:重置密码链接)的前端基础地址
|
||||
# Example: "https://example.com"
|
||||
frontend_url: ""
|
||||
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
|
||||
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
|
||||
trusted_proxies: []
|
||||
@@ -108,9 +112,9 @@ security:
|
||||
# 白名单禁用时是否允许 http:// URL(默认: false,要求 https)
|
||||
allow_insecure_http: true
|
||||
response_headers:
|
||||
# Enable configurable response header filtering (disable to use default allowlist)
|
||||
# 启用可配置的响应头过滤(禁用则使用默认白名单)
|
||||
enabled: false
|
||||
# Enable configurable response header filtering (default: true)
|
||||
# 启用可配置的响应头过滤(默认启用,过滤上游敏感响应头)
|
||||
enabled: true
|
||||
# Extra allowed response headers from upstream
|
||||
# 额外允许的上游响应头
|
||||
additional_allowed: []
|
||||
@@ -151,17 +155,22 @@ gateway:
|
||||
# - account_proxy: Isolate by account+proxy combination (default, finest granularity)
|
||||
# - account_proxy: 按账户+代理组合隔离(默认,最细粒度)
|
||||
connection_pool_isolation: "account_proxy"
|
||||
# Force Codex CLI mode: treat all /openai/v1/responses requests as Codex CLI.
|
||||
# 强制按 Codex CLI 处理 /openai/v1/responses 请求(用于网关未透传/改写 User-Agent 的兜底)。
|
||||
#
|
||||
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
|
||||
force_codex_cli: false
|
||||
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
|
||||
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
|
||||
# Max idle connections across all hosts
|
||||
# 所有主机的最大空闲连接数
|
||||
max_idle_conns: 240
|
||||
max_idle_conns: 2560
|
||||
# Max idle connections per host
|
||||
# 每个主机的最大空闲连接数
|
||||
max_idle_conns_per_host: 120
|
||||
# Max connections per host
|
||||
# 每个主机的最大连接数
|
||||
max_conns_per_host: 240
|
||||
max_conns_per_host: 1024
|
||||
# Idle connection timeout (seconds)
|
||||
# 空闲连接超时时间(秒)
|
||||
idle_conn_timeout_seconds: 90
|
||||
@@ -381,9 +390,22 @@ database:
|
||||
# Database name
|
||||
# 数据库名称
|
||||
dbname: "sub2api"
|
||||
# SSL mode: disable, require, verify-ca, verify-full
|
||||
# SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证)
|
||||
sslmode: "disable"
|
||||
# SSL mode: disable, prefer, require, verify-ca, verify-full
|
||||
# SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证)
|
||||
# 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文
|
||||
sslmode: "prefer"
|
||||
# Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整)
|
||||
# 最大打开连接数
|
||||
max_open_conns: 256
|
||||
# Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销)
|
||||
# 最大空闲连接数
|
||||
max_idle_conns: 128
|
||||
# Connection max lifetime (minutes)
|
||||
# 连接最大存活时间(分钟)
|
||||
conn_max_lifetime_minutes: 30
|
||||
# Connection max idle time (minutes)
|
||||
# 空闲连接最大存活时间(分钟)
|
||||
conn_max_idle_time_minutes: 5
|
||||
|
||||
# =============================================================================
|
||||
# Redis Configuration
|
||||
@@ -402,6 +424,12 @@ redis:
|
||||
# Database number (0-15)
|
||||
# 数据库编号(0-15)
|
||||
db: 0
|
||||
# Connection pool size (max concurrent connections)
|
||||
# 连接池大小(最大并发连接数)
|
||||
pool_size: 1024
|
||||
# Minimum number of idle connections (高并发场景建议 128+,保持足够热连接)
|
||||
# 最小空闲连接数
|
||||
min_idle_conns: 128
|
||||
# Enable TLS/SSL connection
|
||||
# 是否启用 TLS/SSL 连接
|
||||
enable_tls: false
|
||||
|
||||
233
deploy/docker-compose-aicodex.yml
Normal file
233
deploy/docker-compose-aicodex.yml
Normal file
@@ -0,0 +1,233 @@
|
||||
# =============================================================================
|
||||
# Sub2API Docker Compose Host Configuration (Local Build)
|
||||
# =============================================================================
|
||||
# Quick Start:
|
||||
# 1. Copy .env.example to .env and configure
|
||||
# 2. docker-compose -f docker-compose-host.yml up -d --build
|
||||
# 3. Check logs: docker-compose -f docker-compose-host.yml logs -f sub2api
|
||||
# 4. Access: http://localhost:8080
|
||||
#
|
||||
# This configuration builds the image from source (Dockerfile in project root).
|
||||
# All configuration is done via environment variables.
|
||||
# No Setup Wizard needed - the system auto-initializes on first run.
|
||||
# =============================================================================
|
||||
|
||||
services:
|
||||
# ===========================================================================
|
||||
# Sub2API Application
|
||||
# ===========================================================================
|
||||
sub2api:
|
||||
#image: weishaw/sub2api:latest
|
||||
image: yangjianbo/aicodex2api:latest
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Dockerfile
|
||||
container_name: sub2api
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
ulimits:
|
||||
nofile:
|
||||
soft: 800000
|
||||
hard: 800000
|
||||
volumes:
|
||||
# Data persistence (config.yaml will be auto-generated here)
|
||||
- sub2api_data:/app/data
|
||||
# Mount custom config.yaml (optional, overrides auto-generated config)
|
||||
#- ./config.yaml:/app/data/config.yaml:ro
|
||||
environment:
|
||||
# =======================================================================
|
||||
# Auto Setup (REQUIRED for Docker deployment)
|
||||
# =======================================================================
|
||||
- AUTO_SETUP=true
|
||||
|
||||
# =======================================================================
|
||||
# Server Configuration
|
||||
# =======================================================================
|
||||
- SERVER_HOST=0.0.0.0
|
||||
- SERVER_PORT=8080
|
||||
- SERVER_MODE=${SERVER_MODE:-release}
|
||||
- RUN_MODE=${RUN_MODE:-standard}
|
||||
|
||||
# =======================================================================
|
||||
# Database Configuration (PostgreSQL)
|
||||
# =======================================================================
|
||||
# Using host network: point to host/external DB by DATABASE_HOST/DATABASE_PORT
|
||||
- DATABASE_HOST=${DATABASE_HOST:-127.0.0.1}
|
||||
- DATABASE_PORT=${DATABASE_PORT:-5432}
|
||||
- DATABASE_USER=${POSTGRES_USER:-sub2api}
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
||||
|
||||
# =======================================================================
|
||||
# Gateway Configuration
|
||||
# =======================================================================
|
||||
- GATEWAY_FORCE_CODEX_CLI=${GATEWAY_FORCE_CODEX_CLI:-false}
|
||||
- GATEWAY_MAX_IDLE_CONNS=${GATEWAY_MAX_IDLE_CONNS:-2560}
|
||||
- GATEWAY_MAX_IDLE_CONNS_PER_HOST=${GATEWAY_MAX_IDLE_CONNS_PER_HOST:-120}
|
||||
- GATEWAY_MAX_CONNS_PER_HOST=${GATEWAY_MAX_CONNS_PER_HOST:-8192}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration
|
||||
# =======================================================================
|
||||
# Using host network: point to host/external Redis by REDIS_HOST/REDIS_PORT
|
||||
- REDIS_HOST=${REDIS_HOST:-127.0.0.1}
|
||||
- REDIS_PORT=${REDIS_PORT:-6379}
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
||||
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
|
||||
|
||||
# =======================================================================
|
||||
# Admin Account (auto-created on first run)
|
||||
# =======================================================================
|
||||
- ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local}
|
||||
- ADMIN_PASSWORD=${ADMIN_PASSWORD:-}
|
||||
|
||||
# =======================================================================
|
||||
# JWT Configuration
|
||||
# =======================================================================
|
||||
# Leave empty to auto-generate (recommended)
|
||||
- JWT_SECRET=${JWT_SECRET:-}
|
||||
- JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24}
|
||||
|
||||
# =======================================================================
|
||||
# TOTP (2FA) Configuration
|
||||
# =======================================================================
|
||||
# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty,
|
||||
# a random key will be generated on each startup, causing all existing
|
||||
# TOTP configurations to become invalid (users won't be able to login
|
||||
# with 2FA).
|
||||
# Generate a secure key: openssl rand -hex 32
|
||||
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
|
||||
|
||||
# =======================================================================
|
||||
# Timezone Configuration
|
||||
# This affects ALL time operations in the application:
|
||||
# - Database timestamps
|
||||
# - Usage statistics "today" boundary
|
||||
# - Subscription expiry times
|
||||
# - Log timestamps
|
||||
# Common values: Asia/Shanghai, America/New_York, Europe/London, UTC
|
||||
# =======================================================================
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
|
||||
# =======================================================================
|
||||
# Gemini OAuth Configuration (for Gemini accounts)
|
||||
# =======================================================================
|
||||
- GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-}
|
||||
- GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-}
|
||||
- GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-}
|
||||
- GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-}
|
||||
|
||||
# =======================================================================
|
||||
# Security Configuration (URL Allowlist)
|
||||
# =======================================================================
|
||||
# Allow private IP addresses for CRS sync (for internal deployments)
|
||||
- SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS=${SECURITY_URL_ALLOWLIST_ALLOW_PRIVATE_HOSTS:-true}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
# ===========================================================================
|
||||
# PostgreSQL Database
|
||||
# ===========================================================================
|
||||
postgres:
|
||||
image: postgres:18-alpine
|
||||
container_name: sub2api-postgres
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
ulimits:
|
||||
nofile:
|
||||
soft: 800000
|
||||
hard: 800000
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
command:
|
||||
- "postgres"
|
||||
- "-c"
|
||||
- "listen_addresses=127.0.0.1"
|
||||
# 监听端口:与应用侧 DATABASE_PORT 保持一致。
|
||||
- "-c"
|
||||
- "port=${DATABASE_PORT:-5432}"
|
||||
# 连接数上限:需要结合应用侧 DATABASE_MAX_OPEN_CONNS 调整。
|
||||
# 注意:max_connections 过大可能导致内存占用与上下文切换开销显著上升。
|
||||
- "-c"
|
||||
- "max_connections=${POSTGRES_MAX_CONNECTIONS:-1024}"
|
||||
# 典型内存参数(建议结合机器内存调优;不确定就保持默认或小步调大)。
|
||||
- "-c"
|
||||
- "shared_buffers=${POSTGRES_SHARED_BUFFERS:-1GB}"
|
||||
- "-c"
|
||||
- "effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-6GB}"
|
||||
- "-c"
|
||||
- "maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-128MB}"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api} -p ${DATABASE_PORT:-5432}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
# Note: bound to localhost only; not exposed to external network by default.
|
||||
|
||||
# ===========================================================================
|
||||
# Redis Cache
|
||||
# ===========================================================================
|
||||
redis:
|
||||
image: redis:8-alpine
|
||||
container_name: sub2api-redis
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
ulimits:
|
||||
nofile:
|
||||
soft: 100000
|
||||
hard: 100000
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
command: >
|
||||
redis-server
|
||||
--bind 127.0.0.1
|
||||
--port ${REDIS_PORT:-6379}
|
||||
--maxclients ${REDIS_MAXCLIENTS:-50000}
|
||||
--save 60 1
|
||||
--appendonly yes
|
||||
--appendfsync everysec
|
||||
${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}}
|
||||
environment:
|
||||
- TZ=${TZ:-Asia/Shanghai}
|
||||
# REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag)
|
||||
- REDISCLI_AUTH=${REDIS_PASSWORD:-}
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "redis-cli -p ${REDIS_PORT:-6379} -a \"$REDISCLI_AUTH\" ping | grep -q PONG || redis-cli -p ${REDIS_PORT:-6379} ping | grep -q PONG"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 5s
|
||||
|
||||
# =============================================================================
|
||||
# Volumes
|
||||
# =============================================================================
|
||||
volumes:
|
||||
sub2api_data:
|
||||
driver: local
|
||||
postgres_data:
|
||||
driver: local
|
||||
redis_data:
|
||||
driver: local
|
||||
@@ -57,6 +57,10 @@ services:
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration
|
||||
@@ -65,6 +69,8 @@ services:
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
||||
|
||||
# =======================================================================
|
||||
# Admin Account (auto-created on first run)
|
||||
|
||||
@@ -62,6 +62,10 @@ services:
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration
|
||||
@@ -70,6 +74,8 @@ services:
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
||||
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
|
||||
|
||||
# =======================================================================
|
||||
|
||||
@@ -48,6 +48,10 @@ services:
|
||||
- DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api}
|
||||
- DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable}
|
||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration - Required
|
||||
@@ -56,6 +60,8 @@ services:
|
||||
- REDIS_PORT=${REDIS_PORT:-6379}
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
||||
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
|
||||
|
||||
# =======================================================================
|
||||
|
||||
@@ -54,6 +54,10 @@ services:
|
||||
- DATABASE_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
|
||||
- DATABASE_DBNAME=${POSTGRES_DB:-sub2api}
|
||||
- DATABASE_SSLMODE=disable
|
||||
- DATABASE_MAX_OPEN_CONNS=${DATABASE_MAX_OPEN_CONNS:-50}
|
||||
- DATABASE_MAX_IDLE_CONNS=${DATABASE_MAX_IDLE_CONNS:-10}
|
||||
- DATABASE_CONN_MAX_LIFETIME_MINUTES=${DATABASE_CONN_MAX_LIFETIME_MINUTES:-30}
|
||||
- DATABASE_CONN_MAX_IDLE_TIME_MINUTES=${DATABASE_CONN_MAX_IDLE_TIME_MINUTES:-5}
|
||||
|
||||
# =======================================================================
|
||||
# Redis Configuration
|
||||
@@ -62,6 +66,8 @@ services:
|
||||
- REDIS_PORT=6379
|
||||
- REDIS_PASSWORD=${REDIS_PASSWORD:-}
|
||||
- REDIS_DB=${REDIS_DB:-0}
|
||||
- REDIS_POOL_SIZE=${REDIS_POOL_SIZE:-1024}
|
||||
- REDIS_MIN_IDLE_CONNS=${REDIS_MIN_IDLE_CONNS:-10}
|
||||
- REDIS_ENABLE_TLS=${REDIS_ENABLE_TLS:-false}
|
||||
|
||||
# =======================================================================
|
||||
|
||||
Reference in New Issue
Block a user