diff --git a/README_CN.md b/README_CN.md index 87e787f8..2d5d4760 100644 --- a/README_CN.md +++ b/README_CN.md @@ -406,6 +406,14 @@ gateway: - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `turnstile.required` 在 release 模式强制启用 Turnstile +**网关防御纵深建议(重点)** + +- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。 +- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。 +- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。 +- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。 +- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。 + **⚠️ 安全警告:HTTP URL 配置** 当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置: diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 861351de..b9f31ba9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -308,6 +308,12 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` + // 代理探测响应体读取上限(字节) + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` + // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 @@ -1059,6 +1065,9 @@ func setDefaults() { viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) viper.SetDefault("gateway.sora_request_timeout_seconds", 180) @@ -1465,6 +1474,12 @@ func (c *Config) Validate() error { if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") + } + if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") + } if c.Gateway.SoraMaxBodySize < 0 { return fmt.Errorf("gateway.sora_max_body_size must be non-negative") } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index c031d6d6..fb2c5b2a 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1106,7 +1106,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { return } - response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.AccountFromService(account)) } // GetTempUnschedulable handles getting temporary unschedulable status diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 598eb4b3..ca5ee9d7 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go new file mode 100644 index 00000000..4fce5ec1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 470eab45..f5db385b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) continue } - // Error response already handled in Forward, just log - reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("openai.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 65296da4..1ca52c2d 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { assert.Equal(t, "test error", errorObj["message"]) } +func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + // TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 func TestOpenAIHandler_GjsonExtraction(t *testing.T) { tests := []struct { diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 6ab2ff72..3f05ac41 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { return normalizeIP(c.ClientIP()) } +// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。 +// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。 +// 适用于 ACL / 风控等安全敏感场景。 +func GetTrustedClientIP(c *gin.Context) string { + if c == nil { + return "" + } + return normalizeIP(c.ClientIP()) +} + // normalizeIP 规范化 IP 地址,去除端口号和空格。 func normalizeIP(ip string) string { ip = strings.TrimSpace(ip) diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go index c3c90c74..3839403c 100644 --- a/backend/internal/pkg/ip/ip_test.go +++ b/backend/internal/pkg/ip/ip_test.go @@ -3,8 +3,10 @@ package ip import ( + "net/http/httptest" "testing" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) { }) } } + +func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + require.NoError(t, r.SetTrustedProxies(nil)) + + r.GET("/t", func(c *gin.Context) { + c.String(200, GetTrustedClientIP(c)) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + r.ServeHTTP(w, req) + + require.Equal(t, 200, w.Code) + require.Equal(t, "9.9.9.9", w.Body.String()) +} diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 2705d429..989573f2 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -1194,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase. if phaseFilter != "upstream" { - clauses = append(clauses, "COALESCE(status_code, 0) >= 400") + clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400") } if filter.StartTime != nil && !filter.StartTime.IsZero() { @@ -1208,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } if p := strings.TrimSpace(filter.Platform); p != "" { args = append(args, p) - clauses = append(clauses, "platform = $"+itoa(len(args))) + clauses = append(clauses, "e.platform = $"+itoa(len(args))) } if filter.GroupID != nil && *filter.GroupID > 0 { args = append(args, *filter.GroupID) - clauses = append(clauses, "group_id = $"+itoa(len(args))) + clauses = append(clauses, "e.group_id = $"+itoa(len(args))) } if filter.AccountID != nil && *filter.AccountID > 0 { args = append(args, *filter.AccountID) - clauses = append(clauses, "account_id = $"+itoa(len(args))) + clauses = append(clauses, "e.account_id = $"+itoa(len(args))) } if phase := phaseFilter; phase != "" { args = append(args, phase) - clauses = append(clauses, "error_phase = $"+itoa(len(args))) + clauses = append(clauses, "e.error_phase = $"+itoa(len(args))) } if filter != nil { if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" { args = append(args, owner) - clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args))) } if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" { args = append(args, source) - clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args))) } } if resolvedFilter != nil { args = append(args, *resolvedFilter) - clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args))) } // View filter: errors vs excluded vs all. @@ -1246,46 +1246,46 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } switch view { case "", "errors": - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") case "excluded": - clauses = append(clauses, "COALESCE(is_business_limited,false) = true") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true") case "all": // no-op default: // treat unknown as default 'errors' - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") } if len(filter.StatusCodes) > 0 { args = append(args, pq.Array(filter.StatusCodes)) - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") + clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")") } else if filter.StatusCodesOther { // "Other" means: status codes not in the common list. known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529} args = append(args, pq.Array(known)) - clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))") + clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))") } // Exact correlation keys (preferred for request↔upstream linkage). if rid := strings.TrimSpace(filter.RequestID); rid != "" { args = append(args, rid) - clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args))) } if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" { args = append(args, crid) - clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args))) } if q := strings.TrimSpace(filter.Query); q != "" { like := "%" + q + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") + clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")") } if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" { like := "%" + userQuery + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "u.email ILIKE $"+n) + clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")") } return "WHERE " + strings.Join(clauses, " AND "), args diff --git a/backend/internal/repository/ops_repo_error_where_test.go b/backend/internal/repository/ops_repo_error_where_test.go new file mode 100644 index 00000000..9ab1a89a --- /dev/null +++ b/backend/internal/repository/ops_repo_error_where_test.go @@ -0,0 +1,48 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + Query: "ACCESS_DENIED", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "e.request_id ILIKE $") { + t.Fatalf("where should include qualified request_id condition: %s", where) + } + if !strings.Contains(where, "e.client_request_id ILIKE $") { + t.Fatalf("where should include qualified client_request_id condition: %s", where) + } + if !strings.Contains(where, "e.error_message ILIKE $") { + t.Fatalf("where should include qualified error_message condition: %s", where) + } +} + +func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + UserQuery: "admin@", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") { + t.Fatalf("where should include EXISTS user email condition: %s", where) + } +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929c..54de2897 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecure := false allowPrivate := false validateResolvedIP := true + maxResponseBytes := defaultProxyProbeResponseMaxBytes if cfg != nil { insecure = cfg.Security.ProxyProbe.InsecureSkipVerify allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts validateResolvedIP = cfg.Security.URLAllowlist.Enabled + if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 { + maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes + } } if insecure { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") @@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, + maxResponseBytes: maxResponseBytes, } } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) // probeURLs 按优先级排列的探测 URL 列表 @@ -52,6 +58,7 @@ type proxyProbeService struct { insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool + maxResponseBytes int64 } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + maxResponseBytes := s.maxResponseBytes + if maxResponseBytes <= 0 { + maxResponseBytes = defaultProxyProbeResponseMaxBytes + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) } + if int64(len(body)) > maxResponseBytes { + return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes) + } switch parser { case "ip-api": diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d2d8ed40..a8034e98 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -51,6 +51,9 @@ func ProvideRouter( if err := r.SetTrustedProxies(nil); err != nil { log.Printf("Failed to disable trusted proxies: %v", err) } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 8e03f785..7aad1699 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查 IP 限制(白名单/黑名单) // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { - clientIP := ip.GetClientIP(c) + clientIP := ip.GetTrustedClientIP(c) allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 3e33c7e3..f3a6f076 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "ACCESS_DENIED") +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 26d79605..c168820c 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -24,10 +24,19 @@ func RegisterAuthRoutes( // 公开接口 auth := v1.Group("/auth") { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/login/2fa", h.Auth.Login2FA) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) + auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Register) + auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login) + auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login2FA) + auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.SendVerifyCode) // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_integration_test.go b/backend/internal/server/routes/auth_rate_limit_integration_test.go new file mode 100644 index 00000000..8a0ef860 --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_integration_test.go @@ -0,0 +1,111 @@ +//go:build integration + +package routes + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const authRouteRedisImageTag = "redis:8.4-alpine" + +func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) { + ctx := context.Background() + rdb := startAuthRouteRedis(t, ctx) + + router := newAuthRoutesTestRouter(rdb) + const path = "/api/v1/auth/register" + + for i := 1; i <= 6; i++ { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "198.51.100.10:23456" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if i <= 5 { + require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i) + continue + } + require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流") + require.Contains(t, w.Body.String(), "rate limit exceeded") + } +} + +func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + ensureAuthRouteDockerAvailable(t) + + redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + require.NoError(t, rdb.Ping(ctx).Err()) + t.Cleanup(func() { + _ = rdb.Close() + }) + return rdb +} + +func ensureAuthRouteDockerAvailable(t *testing.T) { + t.Helper() + if authRouteDockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过认证限流集成测试") +} + +func authRouteDockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func authRouteUserHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go new file mode 100644 index 00000000..5ce8497c --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -0,0 +1,67 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + v1 := router.Group("/api/v1") + + RegisterAuthRoutes( + v1, + &handler.Handlers{ + Auth: &handler.AuthHandler{}, + Setting: &handler.SettingHandler{}, + }, + servermiddleware.JWTAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + redisClient, + ) + + return router +} + +func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + router := newAuthRoutesTestRouter(rdb) + paths := []string{ + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/auth/login/2fa", + "/api/v1/auth/send-verify-code", + } + + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.10:12345" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path) + require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 83cde19e..0502d352 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) - if account.Platform == PlatformGemini && resp.StatusCode < 400 { + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) @@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - respBody, err := io.ReadAll(resp.Body) + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp - respBody, err = io.ReadAll(resp.Body) + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2fe55137..8670f99a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct { } func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { - // Log response headers for debugging - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") - respBody, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } - var parsed map[string]any if isOAuth { unwrappedBody, uwErr := unwrapGeminiResponse(respBody) if uwErr == nil { respBody = unwrappedBody } - _ = json.Unmarshal(respBody, &parsed) - } else { - _ = json.Unmarshal(respBody, &parsed) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { - // Log response headers for debugging - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") if s.cfg != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index c5888d88..7560f480 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -3,10 +3,15 @@ package service import ( "encoding/json" "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + svc := &GeminiMessagesCompatService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + GeminiDebugResponseHeaders: false, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-RateLimit-Limit": []string{"60"}, + }, + Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)), + } + + usage, err := svc.handleNativeNonStreamingResponse(c, resp, false) + require.NoError(t, err) + require.NotNil(t, usage) + require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 697be994..ac93aa0c 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -313,7 +313,6 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco } log := logger.FromContext(ctx).With(fields...) if result.Matched { - log.Warn("OpenAI codex_cli_only 允许官方客户端请求") return } log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") @@ -1277,6 +1276,29 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( startTime time.Time, ) (*OpenAIForwardResult, error) { if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) if err != nil { return nil, err @@ -1396,6 +1418,37 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( }, nil } +func logOpenAIPassthroughInstructionsRejected( + ctx context.Context, + c *gin.Context, + account *Account, + reqModel string, + rejectReason string, + body []byte, +) { + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + accountName := "" + accountType := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + accountType = strings.TrimSpace(string(account.Type)) + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.String("account_type", accountType), + zap.String("request_model", strings.TrimSpace(reqModel)), + zap.String("reject_reason", strings.TrimSpace(rejectReason)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") +} + func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( ctx context.Context, c *gin.Context, @@ -1688,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( resp *http.Response, c *gin.Context, ) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -2318,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -2877,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { return normalized, changed, nil } +func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) if reasoningEffort == "" { diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index 3d7caf8b..d7c95ada 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -103,7 +103,7 @@ func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) { }) } -func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) { +func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) { logSink, restore := captureStructuredLog(t) defer restore() @@ -119,7 +119,7 @@ func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) { Reason: CodexClientRestrictionReasonNotMatchedUA, }, nil) - require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) + require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求")) } @@ -131,7 +131,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) - c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("OpenAI-Beta", "assistants=v2") @@ -143,7 +143,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { Reason: CodexClientRestrictionReasonNotMatchedUA, }, body) - require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index f6a72610..49658d6d 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -164,7 +164,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali c.Request.Header.Set("Proxy-Authorization", "Basic abc") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) upstreamSSE := strings.Join([]string{ `data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`, @@ -211,6 +211,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。 require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) // 其余关键字段保持原值。 require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) @@ -235,6 +236,59 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali require.NotContains(t, body, "\"name\":\"edit\"") } +func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "responses=experimental") + + // Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。 + originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "requires a non-empty instructions field") + require.Nil(t, upstream.lastReq) + + require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing")) +} + func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go new file mode 100644 index 00000000..aecf69a3 --- /dev/null +++ b/backend/internal/service/upstream_response_limit.go @@ -0,0 +1,38 @@ +package service + +import ( + "errors" + "fmt" + "io" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") + +const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 + +func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { + if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { + return cfg.Gateway.UpstreamResponseReadMaxBytes + } + return defaultUpstreamResponseReadMaxBytes +} + +func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) { + if reader == nil { + return nil, errors.New("response body is nil") + } + if maxBytes <= 0 { + maxBytes = defaultUpstreamResponseReadMaxBytes + } + + body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1)) + if err != nil { + return nil, err + } + if int64(len(body)) > maxBytes { + return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes) + } + return body, nil +} diff --git a/backend/internal/service/upstream_response_limit_test.go b/backend/internal/service/upstream_response_limit_test.go new file mode 100644 index 00000000..b9e5cc6d --- /dev/null +++ b/backend/internal/service/upstream_response_limit_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "bytes" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestResolveUpstreamResponseReadLimit(t *testing.T) { + t.Run("use default when config missing", func(t *testing.T) { + require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil)) + }) + + t.Run("use configured value", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 1234 + require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg)) + }) +} + +func TestReadUpstreamResponseBodyLimited(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2) + require.NoError(t, err) + require.Equal(t, []byte("ok"), body) + }) + + t.Run("exceeds limit", func(t *testing.T) { + body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3) + require.Nil(t, body) + require.Error(t, err) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + }) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index f016fd49..9fd2d391 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -146,6 +146,15 @@ gateway: # Max request body size in bytes (default: 100MB) # 请求体最大字节数(默认 100MB) max_body_size: 104857600 + # Max bytes to read for non-stream upstream responses (default: 8MB) + # 非流式上游响应体读取上限(默认 8MB) + upstream_response_read_max_bytes: 8388608 + # Max bytes to read for proxy probe responses (default: 1MB) + # 代理探测响应体读取上限(默认 1MB) + proxy_probe_response_read_max_bytes: 1048576 + # Enable Gemini upstream response header debug logs (default: false) + # 是否开启 Gemini 上游响应头调试日志(默认 false) + gemini_debug_response_headers: false # Sora max request body size in bytes (0=use max_body_size) # Sora 请求体最大字节数(0=使用 max_body_size) sora_max_body_size: 268435456 diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 7b847c1b..b831c9ff 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -39,16 +39,6 @@ watch( { immediate: true } ) -watch( - () => appStore.siteName, - (newName) => { - if (newName) { - document.title = `${newName} - AI API Gateway` - } - }, - { immediate: true } -) - // Watch for authentication state and manage subscription data watch( () => authStore.isAuthenticated, diff --git a/frontend/src/__tests__/integration/data-import.spec.ts b/frontend/src/__tests__/integration/data-import.spec.ts index 1fe870ab..bc9de148 100644 --- a/frontend/src/__tests__/integration/data-import.spec.ts +++ b/frontend/src/__tests__/integration/data-import.spec.ts @@ -58,12 +58,16 @@ describe('ImportDataModal', () => { const input = wrapper.find('input[type="file"]') const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) + Object.defineProperty(file, 'text', { + value: () => Promise.resolve('invalid json') + }) Object.defineProperty(input.element, 'files', { value: [file] }) await input.trigger('change') await wrapper.find('form').trigger('submit') + await Promise.resolve() expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed') }) diff --git a/frontend/src/__tests__/integration/proxy-data-import.spec.ts b/frontend/src/__tests__/integration/proxy-data-import.spec.ts index f0433898..21bf3a63 100644 --- a/frontend/src/__tests__/integration/proxy-data-import.spec.ts +++ b/frontend/src/__tests__/integration/proxy-data-import.spec.ts @@ -58,12 +58,16 @@ describe('Proxy ImportDataModal', () => { const input = wrapper.find('input[type="file"]') const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) + Object.defineProperty(file, 'text', { + value: () => Promise.resolve('invalid json') + }) Object.defineProperty(input.element, 'files', { value: [file] }) await input.trigger('change') await wrapper.find('form').trigger('submit') + await Promise.resolve() expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed') }) diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4cb1a6f2..65f2090c 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -164,10 +164,10 @@ export async function getUsage(id: number): Promise { /** * Clear account rate limit status * @param id - Account ID - * @returns Success confirmation + * @returns Updated account */ -export async function clearRateLimit(id: number): Promise<{ message: string }> { - const { data } = await apiClient.post<{ message: string }>( +export async function clearRateLimit(id: number): Promise { + const { data } = await apiClient.post( `/admin/accounts/${id}/clear-rate-limit` ) return data diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 838df569..18d2e968 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -209,7 +209,7 @@
('whitelist') const allowedModels = ref([]) const modelMappings = ref([]) +const getModelMappingKey = createStableObjectKeyResolver('bulk-model-mapping') const selectedErrorCodes = ref([]) const customErrorCodeInput = ref(null) const interceptWarmupRequests = ref(false) diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 0047592f..66a1d98e 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -714,7 +714,7 @@
@@ -966,7 +966,7 @@
@@ -2097,6 +2097,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' +import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' // Type for exposed OAuthAuthorizationFlow component @@ -2227,6 +2228,9 @@ const antigravityModelMappings = ref([]) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) +const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') +const getAntigravityModelMappingKey = createStableObjectKeyResolver('create-antigravity-model-mapping') +const getTempUnschedRuleKey = createStableObjectKeyResolver('create-temp-unsched-rule') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') const geminiAIStudioOAuthEnabled = ref(false) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 8986a350..0b6d00c9 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -169,7 +169,7 @@
@@ -542,7 +542,7 @@
@@ -1093,6 +1093,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' +import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { getPresetMappingsByPlatform, commonErrorCodes, @@ -1110,7 +1111,7 @@ interface Props { const props = defineProps() const emit = defineEmits<{ close: [] - updated: [] + updated: [account: Account] }>() const { t } = useI18n() @@ -1158,6 +1159,9 @@ const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) +const getModelMappingKey = createStableObjectKeyResolver('edit-model-mapping') +const getAntigravityModelMappingKey = createStableObjectKeyResolver('edit-antigravity-model-mapping') +const getTempUnschedRuleKey = createStableObjectKeyResolver('edit-temp-unsched-rule') // Mixed channel warning dialog state const showMixedChannelWarning = ref(false) @@ -1845,9 +1849,9 @@ const handleSubmit = async () => { updatePayload.extra = newExtra } - await adminAPI.accounts.update(props.account.id, updatePayload) + const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload) appStore.showSuccess(t('admin.accounts.accountUpdated')) - emit('updated') + emit('updated', updatedAccount) handleClose() } catch (error: any) { // Handle 409 mixed_channel_warning - show confirmation dialog @@ -1875,9 +1879,9 @@ const handleMixedChannelConfirm = async () => { pendingUpdatePayload.value.confirm_mixed_channel_risk = true submitting.value = true try { - await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value) + const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value) appStore.showSuccess(t('admin.accounts.accountUpdated')) - emit('updated') + emit('updated', updatedAccount) handleClose() } catch (error: any) { appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) diff --git a/frontend/src/components/admin/account/ImportDataModal.vue b/frontend/src/components/admin/account/ImportDataModal.vue index 0d6de420..6c120be3 100644 --- a/frontend/src/components/admin/account/ImportDataModal.vue +++ b/frontend/src/components/admin/account/ImportDataModal.vue @@ -143,6 +143,24 @@ const handleClose = () => { emit('close') } +const readFileAsText = async (sourceFile: File): Promise => { + if (typeof sourceFile.text === 'function') { + return sourceFile.text() + } + + if (typeof sourceFile.arrayBuffer === 'function') { + const buffer = await sourceFile.arrayBuffer() + return new TextDecoder().decode(buffer) + } + + return await new Promise((resolve, reject) => { + const reader = new FileReader() + reader.onload = () => resolve(String(reader.result ?? '')) + reader.onerror = () => reject(reader.error || new Error('Failed to read file')) + reader.readAsText(sourceFile) + }) +} + const handleImport = async () => { if (!file.value) { appStore.showError(t('admin.accounts.dataImportSelectFile')) @@ -151,7 +169,7 @@ const handleImport = async () => { importing.value = true try { - const text = await file.value.text() + const text = await readFileAsText(file.value) const dataPayload = JSON.parse(text) const res = await adminAPI.accounts.importData({ diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index 8133e029..eeb3f288 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -216,7 +216,7 @@ interface Props { const props = defineProps() const emit = defineEmits<{ close: [] - reauthorized: [] + reauthorized: [account: Account] }>() const appStore = useAppStore() @@ -370,10 +370,10 @@ const handleExchangeCode = async () => { }) // Clear error status after successful re-authorization - await adminAPI.accounts.clearError(props.account.id) + const updatedAccount = await adminAPI.accounts.clearError(props.account.id) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) - emit('reauthorized') + emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -404,9 +404,9 @@ const handleExchangeCode = async () => { type: 'oauth', credentials }) - await adminAPI.accounts.clearError(props.account.id) + const updatedAccount = await adminAPI.accounts.clearError(props.account.id) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) - emit('reauthorized') + emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -436,9 +436,9 @@ const handleExchangeCode = async () => { type: 'oauth', credentials }) - await adminAPI.accounts.clearError(props.account.id) + const updatedAccount = await adminAPI.accounts.clearError(props.account.id) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) - emit('reauthorized') + emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -475,10 +475,10 @@ const handleExchangeCode = async () => { }) // Clear error status after successful re-authorization - await adminAPI.accounts.clearError(props.account.id) + const updatedAccount = await adminAPI.accounts.clearError(props.account.id) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) - emit('reauthorized') + emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') @@ -518,10 +518,10 @@ const handleCookieAuth = async (sessionKey: string) => { }) // Clear error status after successful re-authorization - await adminAPI.accounts.clearError(props.account.id) + const updatedAccount = await adminAPI.accounts.clearError(props.account.id) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) - emit('reauthorized') + emit('reauthorized', updatedAccount) handleClose() } catch (error: any) { claudeOAuth.error.value = diff --git a/frontend/src/components/admin/proxy/ImportDataModal.vue b/frontend/src/components/admin/proxy/ImportDataModal.vue index 6999ecc1..1ff71551 100644 --- a/frontend/src/components/admin/proxy/ImportDataModal.vue +++ b/frontend/src/components/admin/proxy/ImportDataModal.vue @@ -143,6 +143,24 @@ const handleClose = () => { emit('close') } +const readFileAsText = async (sourceFile: File): Promise => { + if (typeof sourceFile.text === 'function') { + return sourceFile.text() + } + + if (typeof sourceFile.arrayBuffer === 'function') { + const buffer = await sourceFile.arrayBuffer() + return new TextDecoder().decode(buffer) + } + + return await new Promise((resolve, reject) => { + const reader = new FileReader() + reader.onload = () => resolve(String(reader.result ?? '')) + reader.onerror = () => reject(reader.error || new Error('Failed to read file')) + reader.readAsText(sourceFile) + }) +} + const handleImport = async () => { if (!file.value) { appStore.showError(t('admin.proxies.dataImportSelectFile')) @@ -151,7 +169,7 @@ const handleImport = async () => { importing.value = true try { - const text = await file.value.text() + const text = await readFileAsText(file.value) const dataPayload = JSON.parse(text) const res = await adminAPI.proxies.importData({ data: dataPayload }) diff --git a/frontend/src/components/common/DataTable.vue b/frontend/src/components/common/DataTable.vue index c1e4333d..43755301 100644 --- a/frontend/src/components/common/DataTable.vue +++ b/frontend/src/components/common/DataTable.vue @@ -3,7 +3,7 @@ + + diff --git a/frontend/src/components/user/UserAttributesConfigModal.vue b/frontend/src/components/user/UserAttributesConfigModal.vue index 11474a22..9aa41a47 100644 --- a/frontend/src/components/user/UserAttributesConfigModal.vue +++ b/frontend/src/components/user/UserAttributesConfigModal.vue @@ -143,7 +143,7 @@
-
+
(null) const deletingAttribute = ref(null) +const getOptionKey = createStableObjectKeyResolver('user-attr-option') const form = reactive({ key: '', @@ -315,7 +317,7 @@ const openEditModal = (attr: UserAttributeDefinition) => { form.placeholder = attr.placeholder || '' form.required = attr.required form.enabled = attr.enabled - form.options = attr.options ? [...attr.options] : [] + form.options = attr.options ? attr.options.map((opt) => ({ ...opt })) : [] showEditModal.value = true } diff --git a/frontend/src/components/user/profile/TotpDisableDialog.vue b/frontend/src/components/user/profile/TotpDisableDialog.vue index daca4067..cd93764c 100644 --- a/frontend/src/components/user/profile/TotpDisableDialog.vue +++ b/frontend/src/components/user/profile/TotpDisableDialog.vue @@ -88,7 +88,7 @@ diff --git a/frontend/src/components/user/profile/TotpSetupModal.vue b/frontend/src/components/user/profile/TotpSetupModal.vue index 3d9b79ec..b544e75b 100644 --- a/frontend/src/components/user/profile/TotpSetupModal.vue +++ b/frontend/src/components/user/profile/TotpSetupModal.vue @@ -175,7 +175,7 @@ diff --git a/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts b/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts new file mode 100644 index 00000000..0259f902 --- /dev/null +++ b/frontend/src/components/user/profile/__tests__/totp-timer-cleanup.spec.ts @@ -0,0 +1,108 @@ +import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import TotpSetupModal from '@/components/user/profile/TotpSetupModal.vue' +import TotpDisableDialog from '@/components/user/profile/TotpDisableDialog.vue' + +const mocks = vi.hoisted(() => ({ + showSuccess: vi.fn(), + showError: vi.fn(), + getVerificationMethod: vi.fn(), + sendVerifyCode: vi.fn() +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showSuccess: mocks.showSuccess, + showError: mocks.showError + }) +})) + +vi.mock('@/api', () => ({ + totpAPI: { + getVerificationMethod: mocks.getVerificationMethod, + sendVerifyCode: mocks.sendVerifyCode, + initiateSetup: vi.fn(), + enable: vi.fn(), + disable: vi.fn() + } +})) + +const flushPromises = async () => { + await Promise.resolve() + await Promise.resolve() +} + +describe('TOTP 弹窗定时器清理', () => { + let intervalSeed = 1000 + let setIntervalSpy: ReturnType + let clearIntervalSpy: ReturnType + + beforeEach(() => { + intervalSeed = 1000 + mocks.showSuccess.mockReset() + mocks.showError.mockReset() + mocks.getVerificationMethod.mockReset() + mocks.sendVerifyCode.mockReset() + + mocks.getVerificationMethod.mockResolvedValue({ method: 'email' }) + mocks.sendVerifyCode.mockResolvedValue({ success: true }) + + setIntervalSpy = vi.spyOn(window, 'setInterval').mockImplementation(((handler: TimerHandler) => { + void handler + intervalSeed += 1 + return intervalSeed as unknown as number + }) as typeof window.setInterval) + clearIntervalSpy = vi.spyOn(window, 'clearInterval') + }) + + afterEach(() => { + setIntervalSpy.mockRestore() + clearIntervalSpy.mockRestore() + }) + + it('TotpSetupModal 卸载时清理倒计时定时器', async () => { + const wrapper = mount(TotpSetupModal) + await flushPromises() + + const sendButton = wrapper + .findAll('button') + .find((button) => button.text().includes('profile.totp.sendCode')) + + expect(sendButton).toBeTruthy() + await sendButton!.trigger('click') + await flushPromises() + + expect(setIntervalSpy).toHaveBeenCalledTimes(1) + const timerId = setIntervalSpy.mock.results[0]?.value + + wrapper.unmount() + + expect(clearIntervalSpy).toHaveBeenCalledWith(timerId) + }) + + it('TotpDisableDialog 卸载时清理倒计时定时器', async () => { + const wrapper = mount(TotpDisableDialog) + await flushPromises() + + const sendButton = wrapper + .findAll('button') + .find((button) => button.text().includes('profile.totp.sendCode')) + + expect(sendButton).toBeTruthy() + await sendButton!.trigger('click') + await flushPromises() + + expect(setIntervalSpy).toHaveBeenCalledTimes(1) + const timerId = setIntervalSpy.mock.results[0]?.value + + wrapper.unmount() + + expect(clearIntervalSpy).toHaveBeenCalledWith(timerId) + }) +}) diff --git a/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts b/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts new file mode 100644 index 00000000..4866746a --- /dev/null +++ b/frontend/src/composables/__tests__/useKeyedDebouncedSearch.spec.ts @@ -0,0 +1,100 @@ +import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest' +import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch' + +const flushPromises = () => Promise.resolve() + +describe('useKeyedDebouncedSearch', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('为不同 key 独立防抖触发搜索', async () => { + const search = vi.fn().mockResolvedValue([]) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 100, + search, + onSuccess + }) + + searcher.trigger('a', 'foo') + searcher.trigger('b', 'bar') + + expect(search).not.toHaveBeenCalled() + + vi.advanceTimersByTime(100) + await flushPromises() + + expect(search).toHaveBeenCalledTimes(2) + expect(search).toHaveBeenNthCalledWith( + 1, + 'foo', + expect.objectContaining({ key: 'a', signal: expect.any(AbortSignal) }) + ) + expect(search).toHaveBeenNthCalledWith( + 2, + 'bar', + expect.objectContaining({ key: 'b', signal: expect.any(AbortSignal) }) + ) + expect(onSuccess).toHaveBeenCalledTimes(2) + }) + + it('同 key 新请求会取消旧请求并忽略过期响应', async () => { + const resolves: Array<(value: string[]) => void> = [] + const search = vi.fn().mockImplementation( + () => new Promise((resolve) => { + resolves.push(resolve) + }) + ) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 50, + search, + onSuccess + }) + + searcher.trigger('rule-1', 'first') + vi.advanceTimersByTime(50) + await flushPromises() + + searcher.trigger('rule-1', 'second') + vi.advanceTimersByTime(50) + await flushPromises() + + expect(search).toHaveBeenCalledTimes(2) + + resolves[1](['second']) + await flushPromises() + expect(onSuccess).toHaveBeenCalledTimes(1) + expect(onSuccess).toHaveBeenLastCalledWith('rule-1', ['second']) + + resolves[0](['first']) + await flushPromises() + expect(onSuccess).toHaveBeenCalledTimes(1) + }) + + it('clearKey 会取消未执行任务', () => { + const search = vi.fn().mockResolvedValue([]) + const onSuccess = vi.fn() + + const searcher = useKeyedDebouncedSearch({ + delay: 100, + search, + onSuccess + }) + + searcher.trigger('a', 'foo') + searcher.clearKey('a') + + vi.advanceTimersByTime(100) + + expect(search).not.toHaveBeenCalled() + expect(onSuccess).not.toHaveBeenCalled() + }) +}) diff --git a/frontend/src/composables/useKeyedDebouncedSearch.ts b/frontend/src/composables/useKeyedDebouncedSearch.ts new file mode 100644 index 00000000..81133c38 --- /dev/null +++ b/frontend/src/composables/useKeyedDebouncedSearch.ts @@ -0,0 +1,103 @@ +import { getCurrentInstance, onUnmounted } from 'vue' + +export interface KeyedDebouncedSearchContext { + key: string + signal: AbortSignal +} + +interface UseKeyedDebouncedSearchOptions { + delay?: number + search: (keyword: string, context: KeyedDebouncedSearchContext) => Promise + onSuccess: (key: string, result: T) => void + onError?: (key: string, error: unknown) => void +} + +/** + * 多实例隔离的防抖搜索:每个 key 有独立的防抖、请求取消与过期响应保护。 + */ +export function useKeyedDebouncedSearch(options: UseKeyedDebouncedSearchOptions) { + const delay = options.delay ?? 300 + const timers = new Map>() + const controllers = new Map() + const versions = new Map() + + const clearKey = (key: string) => { + const timer = timers.get(key) + if (timer) { + clearTimeout(timer) + timers.delete(key) + } + + const controller = controllers.get(key) + if (controller) { + controller.abort() + controllers.delete(key) + } + + versions.delete(key) + } + + const clearAll = () => { + const allKeys = new Set([ + ...timers.keys(), + ...controllers.keys(), + ...versions.keys() + ]) + + allKeys.forEach((key) => clearKey(key)) + } + + const trigger = (key: string, keyword: string) => { + const nextVersion = (versions.get(key) ?? 0) + 1 + versions.set(key, nextVersion) + + const existingTimer = timers.get(key) + if (existingTimer) { + clearTimeout(existingTimer) + timers.delete(key) + } + + const inFlight = controllers.get(key) + if (inFlight) { + inFlight.abort() + controllers.delete(key) + } + + const timer = setTimeout(async () => { + timers.delete(key) + + const controller = new AbortController() + controllers.set(key, controller) + const requestVersion = versions.get(key) + + try { + const result = await options.search(keyword, { key, signal: controller.signal }) + if (controller.signal.aborted) return + if (versions.get(key) !== requestVersion) return + options.onSuccess(key, result) + } catch (error) { + if (controller.signal.aborted) return + if (versions.get(key) !== requestVersion) return + options.onError?.(key, error) + } finally { + if (controllers.get(key) === controller) { + controllers.delete(key) + } + } + }, delay) + + timers.set(key, timer) + } + + if (getCurrentInstance()) { + onUnmounted(() => { + clearAll() + }) + } + + return { + trigger, + clearKey, + clearAll + } +} diff --git a/frontend/src/i18n/index.ts b/frontend/src/i18n/index.ts index 486fb3bc..00e34dc2 100644 --- a/frontend/src/i18n/index.ts +++ b/frontend/src/i18n/index.ts @@ -1,53 +1,83 @@ import { createI18n } from 'vue-i18n' -import en from './locales/en' -import zh from './locales/zh' + +type LocaleCode = 'en' | 'zh' + +type LocaleMessages = Record const LOCALE_KEY = 'sub2api_locale' +const DEFAULT_LOCALE: LocaleCode = 'en' -function getDefaultLocale(): string { - // Check localStorage first +const localeLoaders: Record Promise<{ default: LocaleMessages }>> = { + en: () => import('./locales/en'), + zh: () => import('./locales/zh') +} + +function isLocaleCode(value: string): value is LocaleCode { + return value === 'en' || value === 'zh' +} + +function getDefaultLocale(): LocaleCode { const saved = localStorage.getItem(LOCALE_KEY) - if (saved && ['en', 'zh'].includes(saved)) { + if (saved && isLocaleCode(saved)) { return saved } - // Check browser language const browserLang = navigator.language.toLowerCase() if (browserLang.startsWith('zh')) { return 'zh' } - return 'en' + return DEFAULT_LOCALE } export const i18n = createI18n({ legacy: false, locale: getDefaultLocale(), - fallbackLocale: 'en', - messages: { - en, - zh - }, + fallbackLocale: DEFAULT_LOCALE, + messages: {}, // 禁用 HTML 消息警告 - 引导步骤使用富文本内容(driver.js 支持 HTML) // 这些内容是内部定义的,不存在 XSS 风险 warnHtmlMessage: false }) -export function setLocale(locale: string) { - if (['en', 'zh'].includes(locale)) { - i18n.global.locale.value = locale as 'en' | 'zh' - localStorage.setItem(LOCALE_KEY, locale) - document.documentElement.setAttribute('lang', locale) +const loadedLocales = new Set() + +export async function loadLocaleMessages(locale: LocaleCode): Promise { + if (loadedLocales.has(locale)) { + return } + + const loader = localeLoaders[locale] + const module = await loader() + i18n.global.setLocaleMessage(locale, module.default) + loadedLocales.add(locale) } -export function getLocale(): string { - return i18n.global.locale.value +export async function initI18n(): Promise { + const current = getLocale() + await loadLocaleMessages(current) + document.documentElement.setAttribute('lang', current) +} + +export async function setLocale(locale: string): Promise { + if (!isLocaleCode(locale)) { + return + } + + await loadLocaleMessages(locale) + i18n.global.locale.value = locale + localStorage.setItem(LOCALE_KEY, locale) + document.documentElement.setAttribute('lang', locale) +} + +export function getLocale(): LocaleCode { + const current = i18n.global.locale.value + return isLocaleCode(current) ? current : DEFAULT_LOCALE } export const availableLocales = [ { code: 'en', name: 'English', flag: '🇺🇸' }, { code: 'zh', name: '中文', flag: '🇨🇳' } -] +] as const export default i18n diff --git a/frontend/src/main.ts b/frontend/src/main.ts index 11c0b1e8..23f9d297 100644 --- a/frontend/src/main.ts +++ b/frontend/src/main.ts @@ -2,28 +2,33 @@ import { createApp } from 'vue' import { createPinia } from 'pinia' import App from './App.vue' import router from './router' -import i18n from './i18n' +import i18n, { initI18n } from './i18n' +import { useAppStore } from '@/stores/app' import './style.css' -const app = createApp(App) -const pinia = createPinia() -app.use(pinia) +async function bootstrap() { + const app = createApp(App) + const pinia = createPinia() + app.use(pinia) -// Initialize settings from injected config BEFORE mounting (prevents flash) -// This must happen after pinia is installed but before router and i18n -import { useAppStore } from '@/stores/app' -const appStore = useAppStore() -appStore.initFromInjectedConfig() + // Initialize settings from injected config BEFORE mounting (prevents flash) + // This must happen after pinia is installed but before router and i18n + const appStore = useAppStore() + appStore.initFromInjectedConfig() -// Set document title immediately after config is loaded -if (appStore.siteName && appStore.siteName !== 'Sub2API') { - document.title = `${appStore.siteName} - AI API Gateway` + // Set document title immediately after config is loaded + if (appStore.siteName && appStore.siteName !== 'Sub2API') { + document.title = `${appStore.siteName} - AI API Gateway` + } + + await initI18n() + + app.use(router) + app.use(i18n) + + // 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染 + await router.isReady() + app.mount('#app') } -app.use(router) -app.use(i18n) - -// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染 -router.isReady().then(() => { - app.mount('#app') -}) +bootstrap() diff --git a/frontend/src/router/__tests__/title.spec.ts b/frontend/src/router/__tests__/title.spec.ts new file mode 100644 index 00000000..3a892837 --- /dev/null +++ b/frontend/src/router/__tests__/title.spec.ts @@ -0,0 +1,25 @@ +import { describe, expect, it } from 'vitest' +import { resolveDocumentTitle } from '@/router/title' + +describe('resolveDocumentTitle', () => { + it('路由存在标题时,使用“路由标题 - 站点名”格式', () => { + expect(resolveDocumentTitle('Usage Records', 'My Site')).toBe('Usage Records - My Site') + }) + + it('路由无标题时,回退到站点名', () => { + expect(resolveDocumentTitle(undefined, 'My Site')).toBe('My Site') + }) + + it('站点名为空时,回退默认站点名', () => { + expect(resolveDocumentTitle('Dashboard', '')).toBe('Dashboard - Sub2API') + expect(resolveDocumentTitle(undefined, ' ')).toBe('Sub2API') + }) + + it('站点名变更时仅影响后续路由标题计算', () => { + const before = resolveDocumentTitle('Admin Dashboard', 'Alpha') + const after = resolveDocumentTitle('Admin Dashboard', 'Beta') + + expect(before).toBe('Admin Dashboard - Alpha') + expect(after).toBe('Admin Dashboard - Beta') + }) +}) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 4bb46cee..1a67cac6 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -8,6 +8,7 @@ import { useAuthStore } from '@/stores/auth' import { useAppStore } from '@/stores/app' import { useNavigationLoadingState } from '@/composables/useNavigationLoading' import { useRoutePrefetch } from '@/composables/useRoutePrefetch' +import { resolveDocumentTitle } from './title' /** * Route definitions with lazy loading @@ -389,12 +390,7 @@ router.beforeEach((to, _from, next) => { // Set page title const appStore = useAppStore() - const siteName = appStore.siteName || 'Sub2API' - if (to.meta.title) { - document.title = `${to.meta.title} - ${siteName}` - } else { - document.title = siteName - } + document.title = resolveDocumentTitle(to.meta.title, appStore.siteName) // Check if route requires authentication const requiresAuth = to.meta.requiresAuth !== false // Default to true diff --git a/frontend/src/router/title.ts b/frontend/src/router/title.ts new file mode 100644 index 00000000..e0db24b0 --- /dev/null +++ b/frontend/src/router/title.ts @@ -0,0 +1,12 @@ +/** + * 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。 + */ +export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string { + const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API' + + if (typeof routeTitle === 'string' && routeTitle.trim()) { + return `${routeTitle.trim()} - ${normalizedSiteName}` + } + + return normalizedSiteName +} diff --git a/frontend/src/utils/__tests__/stableObjectKey.spec.ts b/frontend/src/utils/__tests__/stableObjectKey.spec.ts new file mode 100644 index 00000000..5a6f99f4 --- /dev/null +++ b/frontend/src/utils/__tests__/stableObjectKey.spec.ts @@ -0,0 +1,37 @@ +import { describe, expect, it } from 'vitest' +import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' + +describe('createStableObjectKeyResolver', () => { + it('对同一对象返回稳定 key', () => { + const resolve = createStableObjectKeyResolver<{ value: string }>('rule') + const obj = { value: 'a' } + + const key1 = resolve(obj) + const key2 = resolve(obj) + + expect(key1).toBe(key2) + expect(key1.startsWith('rule-')).toBe(true) + }) + + it('不同对象返回不同 key', () => { + const resolve = createStableObjectKeyResolver<{ value: string }>('rule') + + const key1 = resolve({ value: 'a' }) + const key2 = resolve({ value: 'a' }) + + expect(key1).not.toBe(key2) + }) + + it('不同 resolver 互不影响', () => { + const resolveA = createStableObjectKeyResolver<{ id: number }>('a') + const resolveB = createStableObjectKeyResolver<{ id: number }>('b') + const obj = { id: 1 } + + const keyA = resolveA(obj) + const keyB = resolveB(obj) + + expect(keyA).not.toBe(keyB) + expect(keyA.startsWith('a-')).toBe(true) + expect(keyB.startsWith('b-')).toBe(true) + }) +}) diff --git a/frontend/src/utils/stableObjectKey.ts b/frontend/src/utils/stableObjectKey.ts new file mode 100644 index 00000000..a61414f0 --- /dev/null +++ b/frontend/src/utils/stableObjectKey.ts @@ -0,0 +1,19 @@ +let globalStableObjectKeySeed = 0 + +/** + * 为对象实例生成稳定 key(基于 WeakMap,不污染业务对象) + */ +export function createStableObjectKeyResolver(prefix = 'item') { + const keyMap = new WeakMap() + + return (item: T): string => { + const cached = keyMap.get(item) + if (cached) { + return cached + } + + const key = `${prefix}-${++globalStableObjectKeySeed}` + keyMap.set(item, key) + return key + } +} diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 456fc8d9..a146130e 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -239,8 +239,8 @@ - - + + @@ -694,6 +694,53 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => { } const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() } const handleDataImported = () => { showImportData.value = false; reload() } +const accountMatchesCurrentFilters = (account: Account) => { + if (params.platform && account.platform !== params.platform) return false + if (params.type && account.type !== params.type) return false + if (params.status) { + if (params.status === 'rate_limited') { + if (!account.rate_limit_reset_at) return false + const resetAt = new Date(account.rate_limit_reset_at).getTime() + if (!Number.isFinite(resetAt) || resetAt <= Date.now()) return false + } else if (account.status !== params.status) { + return false + } + } + const search = String(params.search || '').trim().toLowerCase() + if (search && !account.name.toLowerCase().includes(search)) return false + return true +} +const mergeRuntimeFields = (oldAccount: Account, updatedAccount: Account): Account => ({ + ...updatedAccount, + current_concurrency: updatedAccount.current_concurrency ?? oldAccount.current_concurrency, + current_window_cost: updatedAccount.current_window_cost ?? oldAccount.current_window_cost, + active_sessions: updatedAccount.active_sessions ?? oldAccount.active_sessions +}) +const patchAccountInList = (updatedAccount: Account) => { + const index = accounts.value.findIndex(account => account.id === updatedAccount.id) + if (index === -1) return + const mergedAccount = mergeRuntimeFields(accounts.value[index], updatedAccount) + if (!accountMatchesCurrentFilters(mergedAccount)) { + accounts.value = accounts.value.filter(account => account.id !== mergedAccount.id) + selIds.value = selIds.value.filter(id => id !== mergedAccount.id) + if (menu.acc?.id === mergedAccount.id) { + menu.show = false + menu.acc = null + } + return + } + const nextAccounts = [...accounts.value] + nextAccounts[index] = mergedAccount + accounts.value = nextAccounts + if (edAcc.value?.id === mergedAccount.id) edAcc.value = mergedAccount + if (reAuthAcc.value?.id === mergedAccount.id) reAuthAcc.value = mergedAccount + if (tempUnschedAcc.value?.id === mergedAccount.id) tempUnschedAcc.value = mergedAccount + if (deletingAcc.value?.id === mergedAccount.id) deletingAcc.value = mergedAccount + if (menu.acc?.id === mergedAccount.id) menu.acc = mergedAccount +} +const handleAccountUpdated = (updatedAccount: Account) => { + patchAccountInList(updatedAccount) +} const formatExportTimestamp = () => { const now = new Date() const pad2 = (value: number) => String(value).padStart(2, '0') @@ -743,9 +790,32 @@ const closeReAuthModal = () => { showReAuth.value = false; reAuthAcc.value = nul const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true } const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true } const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true } -const handleRefresh = async (a: Account) => { try { await adminAPI.accounts.refreshCredentials(a.id); load() } catch (error) { console.error('Failed to refresh credentials:', error) } } -const handleResetStatus = async (a: Account) => { try { await adminAPI.accounts.clearError(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to reset status:', error) } } -const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } } +const handleRefresh = async (a: Account) => { + try { + const updated = await adminAPI.accounts.refreshCredentials(a.id) + patchAccountInList(updated) + } catch (error) { + console.error('Failed to refresh credentials:', error) + } +} +const handleResetStatus = async (a: Account) => { + try { + const updated = await adminAPI.accounts.clearError(a.id) + patchAccountInList(updated) + appStore.showSuccess(t('common.success')) + } catch (error) { + console.error('Failed to reset status:', error) + } +} +const handleClearRateLimit = async (a: Account) => { + try { + const updated = await adminAPI.accounts.clearRateLimit(a.id) + patchAccountInList(updated) + appStore.showSuccess(t('common.success')) + } catch (error) { + console.error('Failed to clear rate limit:', error) + } +} const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true } const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } } const handleToggleSchedulable = async (a: Account) => { @@ -762,7 +832,17 @@ const handleToggleSchedulable = async (a: Account) => { } } const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true } -const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } } +const handleTempUnschedReset = async () => { + if(!tempUnschedAcc.value) return + try { + const updated = await adminAPI.accounts.clearError(tempUnschedAcc.value.id) + showTempUnsched.value = false + tempUnschedAcc.value = null + patchAccountInList(updated) + } catch (error) { + console.error('Failed to reset temp unscheduled:', error) + } +} const formatExpiresAt = (value: number | null) => { if (!value) return '-' return formatDateTime( diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index c6d15e2d..4d6dccf6 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -759,8 +759,8 @@
@@ -786,7 +786,7 @@ {{ account.name }}