Merge branch 'test' into release

This commit is contained in:
yangjianbo
2026-02-14 12:07:19 +08:00
57 changed files with 1715 additions and 304 deletions

View File

@@ -406,6 +406,14 @@ gateway:
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
- `turnstile.required` 在 release 模式强制启用 Turnstile - `turnstile.required` 在 release 模式强制启用 Turnstile
**网关防御纵深建议(重点)**
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
- `/auth/register``/auth/login``/auth/login/2fa``/auth/send-verify-code` 已提供服务端兜底限流Redis 故障时 fail-close
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
**⚠️ 安全警告HTTP URL 配置** **⚠️ 安全警告HTTP URL 配置**
`security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL例如用于开发或内网测试必须显式设置 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL例如用于开发或内网测试必须显式设置

View File

@@ -308,6 +308,12 @@ type GatewayConfig struct {
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制 // 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"` MaxBodySize int64 `mapstructure:"max_body_size"`
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
// 代理探测响应体读取上限(字节)
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
// ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy // ConnectionPoolIsolation: 上游连接池隔离策略proxy/account/account_proxy
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
@@ -1059,6 +1065,9 @@ func setDefaults() {
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
viper.SetDefault("gateway.sora_request_timeout_seconds", 180) viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
if c.Gateway.MaxBodySize <= 0 { if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive") return fmt.Errorf("gateway.max_body_size must be positive")
} }
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
}
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
if c.Gateway.SoraMaxBodySize < 0 { if c.Gateway.SoraMaxBodySize < 0 {
return fmt.Errorf("gateway.sora_max_body_size must be non-negative") return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
} }

View File

@@ -1106,7 +1106,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return return
} }
response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
} }
// GetTempUnschedulable handles getting temporary unschedulable status // GetTempUnschedulable handles getting temporary unschedulable status

View File

@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
continue continue
} }
// 错误响应已在Forward中处理这里只记录日志 wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
continue continue
} }
// 错误响应已在Forward中处理这里只记录日志 wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Error("gateway.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
h.errorResponse(c, status, errType, message) h.errorResponse(c, status, errType, message)
} }
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse 返回Claude API格式的错误响应 // errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{

View File

@@ -0,0 +1,49 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
assert.Equal(t, "error", parsed["type"])
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &GatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}

View File

@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
) )
continue continue
} }
// Error response already handled in Forward, just log wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) reqLog.Error("openai.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return return
} }
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
h.errorResponse(c, status, errType, message) h.errorResponse(c, status, errType, message)
} }
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
if c == nil || c.Writer == nil || c.Writer.Written() {
return false
}
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
return true
}
// errorResponse returns OpenAI API format error response // errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{ c.JSON(status, gin.H{

View File

@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
assert.Equal(t, "test error", errorObj["message"]) assert.Equal(t, "test error", errorObj["message"])
} }
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.True(t, wrote)
require.Equal(t, http.StatusBadGateway, w.Code)
var parsed map[string]any
err := json.Unmarshal(w.Body.Bytes(), &parsed)
require.NoError(t, err)
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errorObj["type"])
assert.Equal(t, "Upstream request failed", errorObj["message"])
}
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.String(http.StatusTeapot, "already written")
h := &OpenAIGatewayHandler{}
wrote := h.ensureForwardErrorResponse(c, false)
require.False(t, wrote)
require.Equal(t, http.StatusTeapot, w.Code)
assert.Equal(t, "already written", w.Body.String())
}
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 // TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) { func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct { tests := []struct {

View File

@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return normalizeIP(c.ClientIP()) return normalizeIP(c.ClientIP())
} }
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func GetTrustedClientIP(c *gin.Context) string {
if c == nil {
return ""
}
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。 // normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string { func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip) ip = strings.TrimSpace(ip)

View File

@@ -3,8 +3,10 @@
package ip package ip
import ( import (
"net/http/httptest"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
}) })
} }
} }
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
require.NoError(t, r.SetTrustedProxies(nil))
r.GET("/t", func(c *gin.Context) {
c.String(200, GetTrustedClientIP(c))
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
r.ServeHTTP(w, req)
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}

View File

@@ -1194,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
} }
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase. // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" { if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400") clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
} }
if filter.StartTime != nil && !filter.StartTime.IsZero() { if filter.StartTime != nil && !filter.StartTime.IsZero() {
@@ -1208,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
} }
if p := strings.TrimSpace(filter.Platform); p != "" { if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p) args = append(args, p)
clauses = append(clauses, "platform = $"+itoa(len(args))) clauses = append(clauses, "e.platform = $"+itoa(len(args)))
} }
if filter.GroupID != nil && *filter.GroupID > 0 { if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, *filter.GroupID) args = append(args, *filter.GroupID)
clauses = append(clauses, "group_id = $"+itoa(len(args))) clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
} }
if filter.AccountID != nil && *filter.AccountID > 0 { if filter.AccountID != nil && *filter.AccountID > 0 {
args = append(args, *filter.AccountID) args = append(args, *filter.AccountID)
clauses = append(clauses, "account_id = $"+itoa(len(args))) clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
} }
if phase := phaseFilter; phase != "" { if phase := phaseFilter; phase != "" {
args = append(args, phase) args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args))) clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
} }
if filter != nil { if filter != nil {
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" { if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
args = append(args, owner) args = append(args, owner)
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args))) clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
} }
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" { if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
args = append(args, source) args = append(args, source)
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args))) clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
} }
} }
if resolvedFilter != nil { if resolvedFilter != nil {
args = append(args, *resolvedFilter) args = append(args, *resolvedFilter)
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args))) clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
} }
// View filter: errors vs excluded vs all. // View filter: errors vs excluded vs all.
@@ -1246,46 +1246,46 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
} }
switch view { switch view {
case "", "errors": case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false") clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
case "excluded": case "excluded":
clauses = append(clauses, "COALESCE(is_business_limited,false) = true") clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
case "all": case "all":
// no-op // no-op
default: default:
// treat unknown as default 'errors' // treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false") clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
} }
if len(filter.StatusCodes) > 0 { if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes)) args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
} else if filter.StatusCodesOther { } else if filter.StatusCodesOther {
// "Other" means: status codes not in the common list. // "Other" means: status codes not in the common list.
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529} known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
args = append(args, pq.Array(known)) args = append(args, pq.Array(known))
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))") clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
} }
// Exact correlation keys (preferred for request↔upstream linkage). // Exact correlation keys (preferred for request↔upstream linkage).
if rid := strings.TrimSpace(filter.RequestID); rid != "" { if rid := strings.TrimSpace(filter.RequestID); rid != "" {
args = append(args, rid) args = append(args, rid)
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args))) clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
} }
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" { if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
args = append(args, crid) args = append(args, crid)
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args))) clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
} }
if q := strings.TrimSpace(filter.Query); q != "" { if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%" like := "%" + q + "%"
args = append(args, like) args = append(args, like)
n := itoa(len(args)) n := itoa(len(args))
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
} }
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" { if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
like := "%" + userQuery + "%" like := "%" + userQuery + "%"
args = append(args, like) args = append(args, like)
n := itoa(len(args)) n := itoa(len(args))
clauses = append(clauses, "u.email ILIKE $"+n) clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
} }
return "WHERE " + strings.Join(clauses, " AND "), args return "WHERE " + strings.Join(clauses, " AND "), args

View File

@@ -0,0 +1,48 @@
package repository
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
filter := &service.OpsErrorLogFilter{
Query: "ACCESS_DENIED",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "e.request_id ILIKE $") {
t.Fatalf("where should include qualified request_id condition: %s", where)
}
if !strings.Contains(where, "e.client_request_id ILIKE $") {
t.Fatalf("where should include qualified client_request_id condition: %s", where)
}
if !strings.Contains(where, "e.error_message ILIKE $") {
t.Fatalf("where should include qualified error_message condition: %s", where)
}
}
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
filter := &service.OpsErrorLogFilter{
UserQuery: "admin@",
}
where, args := buildOpsErrorLogsWhere(filter)
if where == "" {
t.Fatalf("where should not be empty")
}
if len(args) != 1 {
t.Fatalf("args len = %d, want 1", len(args))
}
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
t.Fatalf("where should include EXISTS user email condition: %s", where)
}
}

View File

@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false insecure := false
allowPrivate := false allowPrivate := false
validateResolvedIP := true validateResolvedIP := true
maxResponseBytes := defaultProxyProbeResponseMaxBytes
if cfg != nil { if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled validateResolvedIP = cfg.Security.URLAllowlist.Enabled
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
}
} }
if insecure { if insecure {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecureSkipVerify: insecure, insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate, allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP, validateResolvedIP: validateResolvedIP,
maxResponseBytes: maxResponseBytes,
} }
} }
const ( const (
defaultProxyProbeTimeout = 30 * time.Second defaultProxyProbeTimeout = 30 * time.Second
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
) )
// probeURLs 按优先级排列的探测 URL 列表 // probeURLs 按优先级排列的探测 URL 列表
@@ -52,6 +58,7 @@ type proxyProbeService struct {
insecureSkipVerify bool insecureSkipVerify bool
allowPrivateHosts bool allowPrivateHosts bool
validateResolvedIP bool validateResolvedIP bool
maxResponseBytes int64
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
} }
body, err := io.ReadAll(resp.Body) maxResponseBytes := s.maxResponseBytes
if maxResponseBytes <= 0 {
maxResponseBytes = defaultProxyProbeResponseMaxBytes
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
if err != nil { if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
} }
if int64(len(body)) > maxResponseBytes {
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
}
switch parser { switch parser {
case "ip-api": case "ip-api":

View File

@@ -51,6 +51,9 @@ func ProvideRouter(
if err := r.SetTrustedProxies(nil); err != nil { if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err) log.Printf("Failed to disable trusted proxies: %v", err)
} }
if cfg.Server.Mode == "release" {
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
}
} }
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)

View File

@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查 IP 限制(白名单/黑名单) // 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c) clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed { if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")

View File

@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code) require.Equal(t, http.StatusOK, w.Code)
} }
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
IPWhitelist: []string{"1.2.3.4"},
}
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
require.NoError(t, router.SetTrustedProxies(nil))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("x-api-key", apiKey.Key)
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusForbidden, w.Code)
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))

View File

@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
{ {
auth.POST("/register", h.Auth.Register) // 注册/登录/2FA/验证码发送均属于高风险入口增加服务端兜底限流Redis 故障时 fail-close
auth.POST("/login", h.Auth.Login) auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
auth.POST("/login/2fa", h.Auth.Login2FA) FailureMode: middleware.RateLimitFailClose,
auth.POST("/send-verify-code", h.Auth.SendVerifyCode) }), h.Auth.Register)
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.Login)
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.Login2FA)
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close // Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,

View File

@@ -0,0 +1,111 @@
//go:build integration
package routes
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const authRouteRedisImageTag = "redis:8.4-alpine"
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
ctx := context.Background()
rdb := startAuthRouteRedis(t, ctx)
router := newAuthRoutesTestRouter(rdb)
const path = "/api/v1/auth/register"
for i := 1; i <= 6; i++ {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "198.51.100.10:23456"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if i <= 5 {
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
continue
}
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
require.Contains(t, w.Body.String(), "rate limit exceeded")
}
}
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureAuthRouteDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
redisHost, err := redisContainer.Host(ctx)
require.NoError(t, err)
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
func ensureAuthRouteDockerAvailable(t *testing.T) {
t.Helper()
if authRouteDockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过认证限流集成测试")
}
func authRouteDockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func authRouteUserHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}

View File

@@ -0,0 +1,67 @@
package routes
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
RegisterAuthRoutes(
v1,
&handler.Handlers{
Auth: &handler.AuthHandler{},
Setting: &handler.SettingHandler{},
},
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
c.Next()
}),
redisClient,
)
return router
}
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
router := newAuthRoutesTestRouter(rdb)
paths := []string{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
}
for _, path := range paths {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "203.0.113.10:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
}
}

View File

@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 不需要重试(成功或不可重试的错误),跳出循环 // 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息) // DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 { if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header { for k, v := range resp.Header {
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := io.ReadAll(resp.Body) maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err return nil, err
} }
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 读取响应体 // 读取响应体
respBody, err := io.ReadAll(resp.Body) maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
_ = resp.Body.Close() _ = resp.Body.Close()
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
return err
}
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err return err
} }
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
resp = retryResp resp = retryResp
respBody, err = io.ReadAll(resp.Body) respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
_ = resp.Body.Close() _ = resp.Body.Close()
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
return err
}
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err return err
} }

View File

@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
} }
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) { func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header { for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
} }
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
} }
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body) maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err return nil, err
} }
var parsed map[string]any
if isOAuth { if isOAuth {
unwrappedBody, uwErr := unwrapGeminiResponse(respBody) unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
if uwErr == nil { if uwErr == nil {
respBody = unwrappedBody respBody = unwrappedBody
} }
_ = json.Unmarshal(respBody, &parsed)
} else {
_ = json.Unmarshal(respBody, &parsed)
} }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
} }
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header { for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
} }
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
} }
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
if s.cfg != nil { if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)

View File

@@ -3,10 +3,15 @@ package service
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
} }
} }
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
defer restore()
svc := &GeminiMessagesCompatService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
GeminiDebugResponseHeaders: false,
},
},
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-RateLimit-Limit": []string{"60"},
},
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
}
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
require.NoError(t, err)
require.NotNil(t, usage)
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{ claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001", "model": "claude-haiku-4-5-20251001",

View File

@@ -313,7 +313,6 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
} }
log := logger.FromContext(ctx).With(fields...) log := logger.FromContext(ctx).With(fields...)
if result.Matched { if result.Matched {
log.Warn("OpenAI codex_cli_only 允许官方客户端请求")
return return
} }
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
@@ -1277,6 +1276,29 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
startTime time.Time, startTime time.Time,
) (*OpenAIForwardResult, error) { ) (*OpenAIForwardResult, error) {
if account != nil && account.Type == AccountTypeOAuth { if account != nil && account.Type == AccountTypeOAuth {
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: http.StatusForbidden,
Passthrough: true,
Kind: "request_error",
Message: rejectMsg,
Detail: rejectReason,
})
logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"type": "forbidden_error",
"message": rejectMsg,
},
})
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
}
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1396,6 +1418,37 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}, nil }, nil
} }
func logOpenAIPassthroughInstructionsRejected(
ctx context.Context,
c *gin.Context,
account *Account,
reqModel string,
rejectReason string,
body []byte,
) {
if ctx == nil {
ctx = context.Background()
}
accountID := int64(0)
accountName := ""
accountType := ""
if account != nil {
accountID = account.ID
accountName = strings.TrimSpace(account.Name)
accountType = strings.TrimSpace(string(account.Type))
}
fields := []zap.Field{
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", accountID),
zap.String("account_name", accountName),
zap.String("account_type", accountType),
zap.String("request_model", strings.TrimSpace(reqModel)),
zap.String("reject_reason", strings.TrimSpace(rejectReason)),
}
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截Codex 请求缺少有效 instructions")
}
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
ctx context.Context, ctx context.Context,
c *gin.Context, c *gin.Context,
@@ -1688,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
resp *http.Response, resp *http.Response,
c *gin.Context, c *gin.Context,
) (*OpenAIUsage, error) { ) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body) maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err return nil, err
} }
@@ -2318,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
} }
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body) maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil { if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err return nil, err
} }
@@ -2877,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
return normalized, changed, nil return normalized, changed, nil
} }
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
model := strings.ToLower(strings.TrimSpace(reqModel))
if !strings.Contains(model, "codex") {
return ""
}
instructions := gjson.GetBytes(body, "instructions")
if !instructions.Exists() {
return "instructions_missing"
}
if instructions.Type != gjson.String {
return "instructions_not_string"
}
if strings.TrimSpace(instructions.String()) == "" {
return "instructions_empty"
}
return ""
}
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if reasoningEffort == "" { if reasoningEffort == "" {

View File

@@ -103,7 +103,7 @@ func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
}) })
} }
func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) { func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) {
logSink, restore := captureStructuredLog(t) logSink, restore := captureStructuredLog(t)
defer restore() defer restore()
@@ -119,7 +119,7 @@ func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
Reason: CodexClientRestrictionReasonNotMatchedUA, Reason: CodexClientRestrictionReasonNotMatchedUA,
}, nil) }, nil)
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求")) require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
} }
@@ -131,7 +131,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec) c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "curl/8.0") c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("OpenAI-Beta", "assistants=v2") c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
@@ -143,7 +143,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
Reason: CodexClientRestrictionReasonNotMatchedUA, Reason: CodexClientRestrictionReasonNotMatchedUA,
}, body) }, body)
require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0")) require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))

View File

@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
c.Request.Header.Set("Proxy-Authorization", "Basic abc") c.Request.Header.Set("Proxy-Authorization", "Basic abc")
c.Request.Header.Set("X-Test", "keep") c.Request.Header.Set("X-Test", "keep")
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`) originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
upstreamSSE := strings.Join([]string{ upstreamSSE := strings.Join([]string{
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`, `data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
@@ -211,6 +211,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致store=false + stream=true。 // 1) 透传 OAuth 请求体与旧链路关键行为保持一致store=false + stream=true。
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String()))
// 其余关键字段保持原值。 // 其余关键字段保持原值。
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
@@ -235,6 +236,59 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
require.NotContains(t, body, "\"name\":\"edit\"") require.NotContains(t, body, "\"name\":\"edit\"")
} }
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
defer restore()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
// Codex 模型且缺少 instructions应在本地直接 403 拒绝,不触达上游。
originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
},
}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
httpUpstream: upstream,
}
account := &Account{
ID: 123,
Name: "acc",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
Extra: map[string]any{"openai_passthrough": true},
Status: StatusActive,
Schedulable: true,
RateMultiplier: f64p(1),
}
result, err := svc.Forward(context.Background(), c, account, originalBody)
require.Error(t, err)
require.Nil(t, result)
require.Equal(t, http.StatusForbidden, rec.Code)
require.Contains(t, rec.Body.String(), "requires a non-empty instructions field")
require.Nil(t, upstream.lastReq)
require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截Codex 请求缺少有效 instructions"))
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing"))
}
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) { func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@@ -0,0 +1,38 @@
package service
import (
"errors"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
return cfg.Gateway.UpstreamResponseReadMaxBytes
}
return defaultUpstreamResponseReadMaxBytes
}
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
if reader == nil {
return nil, errors.New("response body is nil")
}
if maxBytes <= 0 {
maxBytes = defaultUpstreamResponseReadMaxBytes
}
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
if err != nil {
return nil, err
}
if int64(len(body)) > maxBytes {
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
}
return body, nil
}

View File

@@ -0,0 +1,37 @@
package service
import (
"bytes"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
t.Run("use default when config missing", func(t *testing.T) {
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
})
t.Run("use configured value", func(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
})
}
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
t.Run("within limit", func(t *testing.T) {
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
require.NoError(t, err)
require.Equal(t, []byte("ok"), body)
})
t.Run("exceeds limit", func(t *testing.T) {
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
require.Nil(t, body)
require.Error(t, err)
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
})
}

View File

@@ -146,6 +146,15 @@ gateway:
# Max request body size in bytes (default: 100MB) # Max request body size in bytes (default: 100MB)
# 请求体最大字节数(默认 100MB # 请求体最大字节数(默认 100MB
max_body_size: 104857600 max_body_size: 104857600
# Max bytes to read for non-stream upstream responses (default: 8MB)
# 非流式上游响应体读取上限(默认 8MB
upstream_response_read_max_bytes: 8388608
# Max bytes to read for proxy probe responses (default: 1MB)
# 代理探测响应体读取上限(默认 1MB
proxy_probe_response_read_max_bytes: 1048576
# Enable Gemini upstream response header debug logs (default: false)
# 是否开启 Gemini 上游响应头调试日志(默认 false
gemini_debug_response_headers: false
# Sora max request body size in bytes (0=use max_body_size) # Sora max request body size in bytes (0=use max_body_size)
# Sora 请求体最大字节数0=使用 max_body_size # Sora 请求体最大字节数0=使用 max_body_size
sora_max_body_size: 268435456 sora_max_body_size: 268435456

View File

@@ -39,16 +39,6 @@ watch(
{ immediate: true } { immediate: true }
) )
watch(
() => appStore.siteName,
(newName) => {
if (newName) {
document.title = `${newName} - AI API Gateway`
}
},
{ immediate: true }
)
// Watch for authentication state and manage subscription data // Watch for authentication state and manage subscription data
watch( watch(
() => authStore.isAuthenticated, () => authStore.isAuthenticated,

View File

@@ -58,12 +58,16 @@ describe('ImportDataModal', () => {
const input = wrapper.find('input[type="file"]') const input = wrapper.find('input[type="file"]')
const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
Object.defineProperty(file, 'text', {
value: () => Promise.resolve('invalid json')
})
Object.defineProperty(input.element, 'files', { Object.defineProperty(input.element, 'files', {
value: [file] value: [file]
}) })
await input.trigger('change') await input.trigger('change')
await wrapper.find('form').trigger('submit') await wrapper.find('form').trigger('submit')
await Promise.resolve()
expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed') expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed')
}) })

View File

@@ -58,12 +58,16 @@ describe('Proxy ImportDataModal', () => {
const input = wrapper.find('input[type="file"]') const input = wrapper.find('input[type="file"]')
const file = new File(['invalid json'], 'data.json', { type: 'application/json' }) const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
Object.defineProperty(file, 'text', {
value: () => Promise.resolve('invalid json')
})
Object.defineProperty(input.element, 'files', { Object.defineProperty(input.element, 'files', {
value: [file] value: [file]
}) })
await input.trigger('change') await input.trigger('change')
await wrapper.find('form').trigger('submit') await wrapper.find('form').trigger('submit')
await Promise.resolve()
expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed') expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed')
}) })

View File

@@ -164,10 +164,10 @@ export async function getUsage(id: number): Promise<AccountUsageInfo> {
/** /**
* Clear account rate limit status * Clear account rate limit status
* @param id - Account ID * @param id - Account ID
* @returns Success confirmation * @returns Updated account
*/ */
export async function clearRateLimit(id: number): Promise<{ message: string }> { export async function clearRateLimit(id: number): Promise<Account> {
const { data } = await apiClient.post<{ message: string }>( const { data } = await apiClient.post<Account>(
`/admin/accounts/${id}/clear-rate-limit` `/admin/accounts/${id}/clear-rate-limit`
) )
return data return data

View File

@@ -209,7 +209,7 @@
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2"> <div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div <div
v-for="(mapping, index) in modelMappings" v-for="(mapping, index) in modelMappings"
:key="index" :key="getModelMappingKey(mapping)"
class="flex items-center gap-2" class="flex items-center gap-2"
> >
<input <input
@@ -654,6 +654,7 @@ import Select from '@/components/common/Select.vue'
import ProxySelector from '@/components/common/ProxySelector.vue' import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
interface Props { interface Props {
show: boolean show: boolean
@@ -695,6 +696,7 @@ const baseUrl = ref('')
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref<string[]>([]) const allowedModels = ref<string[]>([])
const modelMappings = ref<ModelMapping[]>([]) const modelMappings = ref<ModelMapping[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
const selectedErrorCodes = ref<number[]>([]) const selectedErrorCodes = ref<number[]>([])
const customErrorCodeInput = ref<number | null>(null) const customErrorCodeInput = ref<number | null>(null)
const interceptWarmupRequests = ref(false) const interceptWarmupRequests = ref(false)

View File

@@ -714,7 +714,7 @@
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2"> <div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
<div <div
v-for="(mapping, index) in antigravityModelMappings" v-for="(mapping, index) in antigravityModelMappings"
:key="index" :key="getAntigravityModelMappingKey(mapping)"
class="space-y-1" class="space-y-1"
> >
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
@@ -966,7 +966,7 @@
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2"> <div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div <div
v-for="(mapping, index) in modelMappings" v-for="(mapping, index) in modelMappings"
:key="index" :key="getModelMappingKey(mapping)"
class="flex items-center gap-2" class="flex items-center gap-2"
> >
<input <input
@@ -1225,7 +1225,7 @@
<div v-if="tempUnschedRules.length > 0" class="space-y-3"> <div v-if="tempUnschedRules.length > 0" class="space-y-3">
<div <div
v-for="(rule, index) in tempUnschedRules" v-for="(rule, index) in tempUnschedRules"
:key="index" :key="getTempUnschedRuleKey(rule)"
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600" class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
> >
<div class="mb-2 flex items-center justify-between"> <div class="mb-2 flex items-center justify-between">
@@ -2097,6 +2097,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
// Type for exposed OAuthAuthorizationFlow component // Type for exposed OAuthAuthorizationFlow component
@@ -2227,6 +2228,9 @@ const antigravityModelMappings = ref<ModelMapping[]>([])
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
const tempUnschedEnabled = ref(false) const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([]) const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false) const geminiAIStudioOAuthEnabled = ref(false)

View File

@@ -169,7 +169,7 @@
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2"> <div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
<div <div
v-for="(mapping, index) in modelMappings" v-for="(mapping, index) in modelMappings"
:key="index" :key="getModelMappingKey(mapping)"
class="flex items-center gap-2" class="flex items-center gap-2"
> >
<input <input
@@ -417,7 +417,7 @@
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2"> <div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
<div <div
v-for="(mapping, index) in antigravityModelMappings" v-for="(mapping, index) in antigravityModelMappings"
:key="index" :key="getAntigravityModelMappingKey(mapping)"
class="space-y-1" class="space-y-1"
> >
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
@@ -542,7 +542,7 @@
<div v-if="tempUnschedRules.length > 0" class="space-y-3"> <div v-if="tempUnschedRules.length > 0" class="space-y-3">
<div <div
v-for="(rule, index) in tempUnschedRules" v-for="(rule, index) in tempUnschedRules"
:key="index" :key="getTempUnschedRuleKey(rule)"
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600" class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
> >
<div class="mb-2 flex items-center justify-between"> <div class="mb-2 flex items-center justify-between">
@@ -1093,6 +1093,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
import GroupSelector from '@/components/common/GroupSelector.vue' import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { import {
getPresetMappingsByPlatform, getPresetMappingsByPlatform,
commonErrorCodes, commonErrorCodes,
@@ -1110,7 +1111,7 @@ interface Props {
const props = defineProps<Props>() const props = defineProps<Props>()
const emit = defineEmits<{ const emit = defineEmits<{
close: [] close: []
updated: [] updated: [account: Account]
}>() }>()
const { t } = useI18n() const { t } = useI18n()
@@ -1158,6 +1159,9 @@ const antigravityWhitelistModels = ref<string[]>([])
const antigravityModelMappings = ref<ModelMapping[]>([]) const antigravityModelMappings = ref<ModelMapping[]>([])
const tempUnschedEnabled = ref(false) const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref<TempUnschedRuleForm[]>([]) const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
// Mixed channel warning dialog state // Mixed channel warning dialog state
const showMixedChannelWarning = ref(false) const showMixedChannelWarning = ref(false)
@@ -1845,9 +1849,9 @@ const handleSubmit = async () => {
updatePayload.extra = newExtra updatePayload.extra = newExtra
} }
await adminAPI.accounts.update(props.account.id, updatePayload) const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
appStore.showSuccess(t('admin.accounts.accountUpdated')) appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated') emit('updated', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
// Handle 409 mixed_channel_warning - show confirmation dialog // Handle 409 mixed_channel_warning - show confirmation dialog
@@ -1875,9 +1879,9 @@ const handleMixedChannelConfirm = async () => {
pendingUpdatePayload.value.confirm_mixed_channel_risk = true pendingUpdatePayload.value.confirm_mixed_channel_risk = true
submitting.value = true submitting.value = true
try { try {
await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value) const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
appStore.showSuccess(t('admin.accounts.accountUpdated')) appStore.showSuccess(t('admin.accounts.accountUpdated'))
emit('updated') emit('updated', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))

View File

@@ -143,6 +143,24 @@ const handleClose = () => {
emit('close') emit('close')
} }
const readFileAsText = async (sourceFile: File): Promise<string> => {
if (typeof sourceFile.text === 'function') {
return sourceFile.text()
}
if (typeof sourceFile.arrayBuffer === 'function') {
const buffer = await sourceFile.arrayBuffer()
return new TextDecoder().decode(buffer)
}
return await new Promise<string>((resolve, reject) => {
const reader = new FileReader()
reader.onload = () => resolve(String(reader.result ?? ''))
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
reader.readAsText(sourceFile)
})
}
const handleImport = async () => { const handleImport = async () => {
if (!file.value) { if (!file.value) {
appStore.showError(t('admin.accounts.dataImportSelectFile')) appStore.showError(t('admin.accounts.dataImportSelectFile'))
@@ -151,7 +169,7 @@ const handleImport = async () => {
importing.value = true importing.value = true
try { try {
const text = await file.value.text() const text = await readFileAsText(file.value)
const dataPayload = JSON.parse(text) const dataPayload = JSON.parse(text)
const res = await adminAPI.accounts.importData({ const res = await adminAPI.accounts.importData({

View File

@@ -216,7 +216,7 @@ interface Props {
const props = defineProps<Props>() const props = defineProps<Props>()
const emit = defineEmits<{ const emit = defineEmits<{
close: [] close: []
reauthorized: [] reauthorized: [account: Account]
}>() }>()
const appStore = useAppStore() const appStore = useAppStore()
@@ -370,10 +370,10 @@ const handleExchangeCode = async () => {
}) })
// Clear error status after successful re-authorization // Clear error status after successful re-authorization
await adminAPI.accounts.clearError(props.account.id) const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized') emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -404,9 +404,9 @@ const handleExchangeCode = async () => {
type: 'oauth', type: 'oauth',
credentials credentials
}) })
await adminAPI.accounts.clearError(props.account.id) const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized') emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -436,9 +436,9 @@ const handleExchangeCode = async () => {
type: 'oauth', type: 'oauth',
credentials credentials
}) })
await adminAPI.accounts.clearError(props.account.id) const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized') emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -475,10 +475,10 @@ const handleExchangeCode = async () => {
}) })
// Clear error status after successful re-authorization // Clear error status after successful re-authorization
await adminAPI.accounts.clearError(props.account.id) const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized') emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
@@ -518,10 +518,10 @@ const handleCookieAuth = async (sessionKey: string) => {
}) })
// Clear error status after successful re-authorization // Clear error status after successful re-authorization
await adminAPI.accounts.clearError(props.account.id) const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess')) appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
emit('reauthorized') emit('reauthorized', updatedAccount)
handleClose() handleClose()
} catch (error: any) { } catch (error: any) {
claudeOAuth.error.value = claudeOAuth.error.value =

View File

@@ -143,6 +143,24 @@ const handleClose = () => {
emit('close') emit('close')
} }
const readFileAsText = async (sourceFile: File): Promise<string> => {
if (typeof sourceFile.text === 'function') {
return sourceFile.text()
}
if (typeof sourceFile.arrayBuffer === 'function') {
const buffer = await sourceFile.arrayBuffer()
return new TextDecoder().decode(buffer)
}
return await new Promise<string>((resolve, reject) => {
const reader = new FileReader()
reader.onload = () => resolve(String(reader.result ?? ''))
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
reader.readAsText(sourceFile)
})
}
const handleImport = async () => { const handleImport = async () => {
if (!file.value) { if (!file.value) {
appStore.showError(t('admin.proxies.dataImportSelectFile')) appStore.showError(t('admin.proxies.dataImportSelectFile'))
@@ -151,7 +169,7 @@ const handleImport = async () => {
importing.value = true importing.value = true
try { try {
const text = await file.value.text() const text = await readFileAsText(file.value)
const dataPayload = JSON.parse(text) const dataPayload = JSON.parse(text)
const res = await adminAPI.proxies.importData({ data: dataPayload }) const res = await adminAPI.proxies.importData({ data: dataPayload })

View File

@@ -3,7 +3,7 @@
<template v-if="loading"> <template v-if="loading">
<div v-for="i in 5" :key="i" class="rounded-lg border border-gray-200 bg-white p-4 dark:border-dark-700 dark:bg-dark-900"> <div v-for="i in 5" :key="i" class="rounded-lg border border-gray-200 bg-white p-4 dark:border-dark-700 dark:bg-dark-900">
<div class="space-y-3"> <div class="space-y-3">
<div v-for="column in columns.filter(c => c.key !== 'actions')" :key="column.key" class="flex justify-between"> <div v-for="column in dataColumns" :key="column.key" class="flex justify-between">
<div class="h-4 w-20 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div> <div class="h-4 w-20 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
<div class="h-4 w-32 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div> <div class="h-4 w-32 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
</div> </div>
@@ -39,7 +39,7 @@
> >
<div class="space-y-3"> <div class="space-y-3">
<div <div
v-for="column in columns.filter(c => c.key !== 'actions')" v-for="column in dataColumns"
:key="column.key" :key="column.key"
class="flex items-start justify-between gap-4" class="flex items-start justify-between gap-4"
> >
@@ -439,10 +439,15 @@ const resolveRowKey = (row: any, index: number) => {
return key ?? index return key ?? index
} }
const dataColumns = computed(() => props.columns.filter((column) => column.key !== 'actions'))
const columnsSignature = computed(() =>
props.columns.map((column) => `${column.key}:${column.sortable ? '1' : '0'}`).join('|')
)
// 数据/列变化时重新检查滚动状态 // 数据/列变化时重新检查滚动状态
// 注意:不能监听 actionsExpanded因为 checkActionsColumnWidth 会临时修改它,会导致无限循环 // 注意:不能监听 actionsExpanded因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
watch( watch(
[() => props.data.length, () => props.columns], [() => props.data.length, columnsSignature],
async () => { async () => {
await nextTick() await nextTick()
checkScrollable() checkScrollable()
@@ -555,7 +560,7 @@ onMounted(() => {
}) })
watch( watch(
() => props.columns, columnsSignature,
() => { () => {
// If current sort key is no longer sortable/visible, fall back to default/persisted. // If current sort key is no longer sortable/visible, fall back to default/persisted.
const normalized = normalizeSortKey(sortKey.value) const normalized = normalizeSortKey(sortKey.value)
@@ -575,7 +580,7 @@ watch(
} }
} }
}, },
{ deep: true } { flush: 'post' }
) )
watch( watch(

View File

@@ -2,6 +2,7 @@
<div class="relative" ref="dropdownRef"> <div class="relative" ref="dropdownRef">
<button <button
@click="toggleDropdown" @click="toggleDropdown"
:disabled="switching"
class="flex items-center gap-1.5 rounded-lg px-2 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700" class="flex items-center gap-1.5 rounded-lg px-2 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700"
:title="currentLocale?.name" :title="currentLocale?.name"
> >
@@ -23,6 +24,7 @@
<button <button
v-for="locale in availableLocales" v-for="locale in availableLocales"
:key="locale.code" :key="locale.code"
:disabled="switching"
@click="selectLocale(locale.code)" @click="selectLocale(locale.code)"
class="flex w-full items-center gap-2 px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-100 dark:text-gray-200 dark:hover:bg-dark-700" class="flex w-full items-center gap-2 px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-100 dark:text-gray-200 dark:hover:bg-dark-700"
:class="{ :class="{
@@ -49,6 +51,7 @@ const { locale } = useI18n()
const isOpen = ref(false) const isOpen = ref(false)
const dropdownRef = ref<HTMLElement | null>(null) const dropdownRef = ref<HTMLElement | null>(null)
const switching = ref(false)
const currentLocaleCode = computed(() => locale.value) const currentLocaleCode = computed(() => locale.value)
const currentLocale = computed(() => availableLocales.find((l) => l.code === locale.value)) const currentLocale = computed(() => availableLocales.find((l) => l.code === locale.value))
@@ -57,9 +60,18 @@ function toggleDropdown() {
isOpen.value = !isOpen.value isOpen.value = !isOpen.value
} }
function selectLocale(code: string) { async function selectLocale(code: string) {
setLocale(code) if (switching.value || code === currentLocaleCode.value) {
isOpen.value = false isOpen.value = false
return
}
switching.value = true
try {
await setLocale(code)
isOpen.value = false
} finally {
switching.value = false
}
} }
function handleClickOutside(event: MouseEvent) { function handleClickOutside(event: MouseEvent) {

View File

@@ -84,8 +84,8 @@
<!-- Page numbers --> <!-- Page numbers -->
<button <button
v-for="pageNum in visiblePages" v-for="(pageNum, index) in visiblePages"
:key="pageNum" :key="`${pageNum}-${index}`"
@click="typeof pageNum === 'number' && goToPage(pageNum)" @click="typeof pageNum === 'number' && goToPage(pageNum)"
:disabled="typeof pageNum !== 'number'" :disabled="typeof pageNum !== 'number'"
:class="[ :class="[

View File

@@ -66,8 +66,8 @@
<!-- Progress bar --> <!-- Progress bar -->
<div v-if="toast.duration" class="h-1 bg-gray-100 dark:bg-dark-700"> <div v-if="toast.duration" class="h-1 bg-gray-100 dark:bg-dark-700">
<div <div
:class="['h-full transition-all', getProgressBarColor(toast.type)]" :class="['h-full toast-progress', getProgressBarColor(toast.type)]"
:style="{ width: `${getProgress(toast)}%` }" :style="{ animationDuration: `${toast.duration}ms` }"
></div> ></div>
</div> </div>
</div> </div>
@@ -77,7 +77,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { computed, onMounted, onUnmounted } from 'vue' import { computed } from 'vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
@@ -129,36 +129,25 @@ const getProgressBarColor = (type: string): string => {
return colors[type] || colors.info return colors[type] || colors.info
} }
const getProgress = (toast: any): number => {
if (!toast.duration || !toast.startTime) return 100
const elapsed = Date.now() - toast.startTime
const progress = Math.max(0, 100 - (elapsed / toast.duration) * 100)
return progress
}
const removeToast = (id: string) => { const removeToast = (id: string) => {
appStore.hideToast(id) appStore.hideToast(id)
} }
let intervalId: number | undefined
onMounted(() => {
// Check for expired toasts every 100ms
intervalId = window.setInterval(() => {
const now = Date.now()
toasts.value.forEach((toast) => {
if (toast.duration && toast.startTime) {
if (now - toast.startTime >= toast.duration) {
removeToast(toast.id)
}
}
})
}, 100)
})
onUnmounted(() => {
if (intervalId !== undefined) {
clearInterval(intervalId)
}
})
</script> </script>
<style scoped>
.toast-progress {
width: 100%;
animation-name: toast-progress-shrink;
animation-timing-function: linear;
animation-fill-mode: forwards;
}
@keyframes toast-progress-shrink {
from {
width: 100%;
}
to {
width: 0%;
}
}
</style>

View File

@@ -143,7 +143,7 @@
<!-- Options (for select/multi_select) --> <!-- Options (for select/multi_select) -->
<div v-if="form.type === 'select' || form.type === 'multi_select'" class="space-y-2"> <div v-if="form.type === 'select' || form.type === 'multi_select'" class="space-y-2">
<label class="input-label">{{ t('admin.users.attributes.options') }}</label> <label class="input-label">{{ t('admin.users.attributes.options') }}</label>
<div v-for="(option, index) in form.options" :key="index" class="flex items-center gap-2"> <div v-for="(option, index) in form.options" :key="getOptionKey(option)" class="flex items-center gap-2">
<input <input
v-model="option.value" v-model="option.value"
type="text" type="text"
@@ -246,6 +246,7 @@ import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import Select from '@/components/common/Select.vue' import Select from '@/components/common/Select.vue'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
@@ -270,6 +271,7 @@ const showEditModal = ref(false)
const showDeleteDialog = ref(false) const showDeleteDialog = ref(false)
const editingAttribute = ref<UserAttributeDefinition | null>(null) const editingAttribute = ref<UserAttributeDefinition | null>(null)
const deletingAttribute = ref<UserAttributeDefinition | null>(null) const deletingAttribute = ref<UserAttributeDefinition | null>(null)
const getOptionKey = createStableObjectKeyResolver<UserAttributeOption>('user-attr-option')
const form = reactive({ const form = reactive({
key: '', key: '',
@@ -315,7 +317,7 @@ const openEditModal = (attr: UserAttributeDefinition) => {
form.placeholder = attr.placeholder || '' form.placeholder = attr.placeholder || ''
form.required = attr.required form.required = attr.required
form.enabled = attr.enabled form.enabled = attr.enabled
form.options = attr.options ? [...attr.options] : [] form.options = attr.options ? attr.options.map((opt) => ({ ...opt })) : []
showEditModal.value = true showEditModal.value = true
} }

View File

@@ -88,7 +88,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, computed } from 'vue' import { ref, onMounted, onUnmounted, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { totpAPI } from '@/api' import { totpAPI } from '@/api'
@@ -107,6 +107,7 @@ const loading = ref(false)
const error = ref('') const error = ref('')
const sendingCode = ref(false) const sendingCode = ref(false)
const codeCooldown = ref(0) const codeCooldown = ref(0)
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
const form = ref({ const form = ref({
emailCode: '', emailCode: '',
password: '' password: ''
@@ -139,10 +140,17 @@ const handleSendCode = async () => {
appStore.showSuccess(t('profile.totp.codeSent')) appStore.showSuccess(t('profile.totp.codeSent'))
// Start cooldown // Start cooldown
codeCooldown.value = 60 codeCooldown.value = 60
const timer = setInterval(() => { if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
cooldownTimer.value = setInterval(() => {
codeCooldown.value-- codeCooldown.value--
if (codeCooldown.value <= 0) { if (codeCooldown.value <= 0) {
clearInterval(timer) if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
} }
}, 1000) }, 1000)
} catch (err: any) { } catch (err: any) {
@@ -176,4 +184,11 @@ const handleDisable = async () => {
onMounted(() => { onMounted(() => {
loadVerificationMethod() loadVerificationMethod()
}) })
onUnmounted(() => {
if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
})
</script> </script>

View File

@@ -175,7 +175,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, nextTick, watch, computed } from 'vue' import { ref, onMounted, onUnmounted, nextTick, watch, computed } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { totpAPI } from '@/api' import { totpAPI } from '@/api'
@@ -198,6 +198,7 @@ const verifyForm = ref({ emailCode: '', password: '' })
const verifyError = ref('') const verifyError = ref('')
const sendingCode = ref(false) const sendingCode = ref(false)
const codeCooldown = ref(0) const codeCooldown = ref(0)
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
const setupLoading = ref(false) const setupLoading = ref(false)
const setupData = ref<TotpSetupResponse | null>(null) const setupData = ref<TotpSetupResponse | null>(null)
@@ -338,10 +339,17 @@ const handleSendCode = async () => {
appStore.showSuccess(t('profile.totp.codeSent')) appStore.showSuccess(t('profile.totp.codeSent'))
// Start cooldown // Start cooldown
codeCooldown.value = 60 codeCooldown.value = 60
const timer = setInterval(() => { if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
cooldownTimer.value = setInterval(() => {
codeCooldown.value-- codeCooldown.value--
if (codeCooldown.value <= 0) { if (codeCooldown.value <= 0) {
clearInterval(timer) if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
} }
}, 1000) }, 1000)
} catch (err: any) { } catch (err: any) {
@@ -397,4 +405,11 @@ const handleVerify = async () => {
onMounted(() => { onMounted(() => {
loadVerificationMethod() loadVerificationMethod()
}) })
onUnmounted(() => {
if (cooldownTimer.value) {
clearInterval(cooldownTimer.value)
cooldownTimer.value = null
}
})
</script> </script>

View File

@@ -0,0 +1,108 @@
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
import { mount } from '@vue/test-utils'
import TotpSetupModal from '@/components/user/profile/TotpSetupModal.vue'
import TotpDisableDialog from '@/components/user/profile/TotpDisableDialog.vue'
const mocks = vi.hoisted(() => ({
showSuccess: vi.fn(),
showError: vi.fn(),
getVerificationMethod: vi.fn(),
sendVerifyCode: vi.fn()
}))
vi.mock('vue-i18n', () => ({
useI18n: () => ({
t: (key: string) => key
})
}))
vi.mock('@/stores/app', () => ({
useAppStore: () => ({
showSuccess: mocks.showSuccess,
showError: mocks.showError
})
}))
vi.mock('@/api', () => ({
totpAPI: {
getVerificationMethod: mocks.getVerificationMethod,
sendVerifyCode: mocks.sendVerifyCode,
initiateSetup: vi.fn(),
enable: vi.fn(),
disable: vi.fn()
}
}))
const flushPromises = async () => {
await Promise.resolve()
await Promise.resolve()
}
describe('TOTP 弹窗定时器清理', () => {
let intervalSeed = 1000
let setIntervalSpy: ReturnType<typeof vi.spyOn>
let clearIntervalSpy: ReturnType<typeof vi.spyOn>
beforeEach(() => {
intervalSeed = 1000
mocks.showSuccess.mockReset()
mocks.showError.mockReset()
mocks.getVerificationMethod.mockReset()
mocks.sendVerifyCode.mockReset()
mocks.getVerificationMethod.mockResolvedValue({ method: 'email' })
mocks.sendVerifyCode.mockResolvedValue({ success: true })
setIntervalSpy = vi.spyOn(window, 'setInterval').mockImplementation(((handler: TimerHandler) => {
void handler
intervalSeed += 1
return intervalSeed as unknown as number
}) as typeof window.setInterval)
clearIntervalSpy = vi.spyOn(window, 'clearInterval')
})
afterEach(() => {
setIntervalSpy.mockRestore()
clearIntervalSpy.mockRestore()
})
it('TotpSetupModal 卸载时清理倒计时定时器', async () => {
const wrapper = mount(TotpSetupModal)
await flushPromises()
const sendButton = wrapper
.findAll('button')
.find((button) => button.text().includes('profile.totp.sendCode'))
expect(sendButton).toBeTruthy()
await sendButton!.trigger('click')
await flushPromises()
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
const timerId = setIntervalSpy.mock.results[0]?.value
wrapper.unmount()
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
})
it('TotpDisableDialog 卸载时清理倒计时定时器', async () => {
const wrapper = mount(TotpDisableDialog)
await flushPromises()
const sendButton = wrapper
.findAll('button')
.find((button) => button.text().includes('profile.totp.sendCode'))
expect(sendButton).toBeTruthy()
await sendButton!.trigger('click')
await flushPromises()
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
const timerId = setIntervalSpy.mock.results[0]?.value
wrapper.unmount()
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
})
})

View File

@@ -0,0 +1,100 @@
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
const flushPromises = () => Promise.resolve()
describe('useKeyedDebouncedSearch', () => {
beforeEach(() => {
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
it('为不同 key 独立防抖触发搜索', async () => {
const search = vi.fn().mockResolvedValue([])
const onSuccess = vi.fn()
const searcher = useKeyedDebouncedSearch<string[]>({
delay: 100,
search,
onSuccess
})
searcher.trigger('a', 'foo')
searcher.trigger('b', 'bar')
expect(search).not.toHaveBeenCalled()
vi.advanceTimersByTime(100)
await flushPromises()
expect(search).toHaveBeenCalledTimes(2)
expect(search).toHaveBeenNthCalledWith(
1,
'foo',
expect.objectContaining({ key: 'a', signal: expect.any(AbortSignal) })
)
expect(search).toHaveBeenNthCalledWith(
2,
'bar',
expect.objectContaining({ key: 'b', signal: expect.any(AbortSignal) })
)
expect(onSuccess).toHaveBeenCalledTimes(2)
})
it('同 key 新请求会取消旧请求并忽略过期响应', async () => {
const resolves: Array<(value: string[]) => void> = []
const search = vi.fn().mockImplementation(
() => new Promise<string[]>((resolve) => {
resolves.push(resolve)
})
)
const onSuccess = vi.fn()
const searcher = useKeyedDebouncedSearch<string[]>({
delay: 50,
search,
onSuccess
})
searcher.trigger('rule-1', 'first')
vi.advanceTimersByTime(50)
await flushPromises()
searcher.trigger('rule-1', 'second')
vi.advanceTimersByTime(50)
await flushPromises()
expect(search).toHaveBeenCalledTimes(2)
resolves[1](['second'])
await flushPromises()
expect(onSuccess).toHaveBeenCalledTimes(1)
expect(onSuccess).toHaveBeenLastCalledWith('rule-1', ['second'])
resolves[0](['first'])
await flushPromises()
expect(onSuccess).toHaveBeenCalledTimes(1)
})
it('clearKey 会取消未执行任务', () => {
const search = vi.fn().mockResolvedValue([])
const onSuccess = vi.fn()
const searcher = useKeyedDebouncedSearch<string[]>({
delay: 100,
search,
onSuccess
})
searcher.trigger('a', 'foo')
searcher.clearKey('a')
vi.advanceTimersByTime(100)
expect(search).not.toHaveBeenCalled()
expect(onSuccess).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,103 @@
import { getCurrentInstance, onUnmounted } from 'vue'
export interface KeyedDebouncedSearchContext {
key: string
signal: AbortSignal
}
interface UseKeyedDebouncedSearchOptions<T> {
delay?: number
search: (keyword: string, context: KeyedDebouncedSearchContext) => Promise<T>
onSuccess: (key: string, result: T) => void
onError?: (key: string, error: unknown) => void
}
/**
* 多实例隔离的防抖搜索:每个 key 有独立的防抖、请求取消与过期响应保护。
*/
export function useKeyedDebouncedSearch<T>(options: UseKeyedDebouncedSearchOptions<T>) {
const delay = options.delay ?? 300
const timers = new Map<string, ReturnType<typeof setTimeout>>()
const controllers = new Map<string, AbortController>()
const versions = new Map<string, number>()
const clearKey = (key: string) => {
const timer = timers.get(key)
if (timer) {
clearTimeout(timer)
timers.delete(key)
}
const controller = controllers.get(key)
if (controller) {
controller.abort()
controllers.delete(key)
}
versions.delete(key)
}
const clearAll = () => {
const allKeys = new Set<string>([
...timers.keys(),
...controllers.keys(),
...versions.keys()
])
allKeys.forEach((key) => clearKey(key))
}
const trigger = (key: string, keyword: string) => {
const nextVersion = (versions.get(key) ?? 0) + 1
versions.set(key, nextVersion)
const existingTimer = timers.get(key)
if (existingTimer) {
clearTimeout(existingTimer)
timers.delete(key)
}
const inFlight = controllers.get(key)
if (inFlight) {
inFlight.abort()
controllers.delete(key)
}
const timer = setTimeout(async () => {
timers.delete(key)
const controller = new AbortController()
controllers.set(key, controller)
const requestVersion = versions.get(key)
try {
const result = await options.search(keyword, { key, signal: controller.signal })
if (controller.signal.aborted) return
if (versions.get(key) !== requestVersion) return
options.onSuccess(key, result)
} catch (error) {
if (controller.signal.aborted) return
if (versions.get(key) !== requestVersion) return
options.onError?.(key, error)
} finally {
if (controllers.get(key) === controller) {
controllers.delete(key)
}
}
}, delay)
timers.set(key, timer)
}
if (getCurrentInstance()) {
onUnmounted(() => {
clearAll()
})
}
return {
trigger,
clearKey,
clearAll
}
}

View File

@@ -1,53 +1,83 @@
import { createI18n } from 'vue-i18n' import { createI18n } from 'vue-i18n'
import en from './locales/en'
import zh from './locales/zh' type LocaleCode = 'en' | 'zh'
type LocaleMessages = Record<string, any>
const LOCALE_KEY = 'sub2api_locale' const LOCALE_KEY = 'sub2api_locale'
const DEFAULT_LOCALE: LocaleCode = 'en'
function getDefaultLocale(): string { const localeLoaders: Record<LocaleCode, () => Promise<{ default: LocaleMessages }>> = {
// Check localStorage first en: () => import('./locales/en'),
zh: () => import('./locales/zh')
}
function isLocaleCode(value: string): value is LocaleCode {
return value === 'en' || value === 'zh'
}
function getDefaultLocale(): LocaleCode {
const saved = localStorage.getItem(LOCALE_KEY) const saved = localStorage.getItem(LOCALE_KEY)
if (saved && ['en', 'zh'].includes(saved)) { if (saved && isLocaleCode(saved)) {
return saved return saved
} }
// Check browser language
const browserLang = navigator.language.toLowerCase() const browserLang = navigator.language.toLowerCase()
if (browserLang.startsWith('zh')) { if (browserLang.startsWith('zh')) {
return 'zh' return 'zh'
} }
return 'en' return DEFAULT_LOCALE
} }
export const i18n = createI18n({ export const i18n = createI18n({
legacy: false, legacy: false,
locale: getDefaultLocale(), locale: getDefaultLocale(),
fallbackLocale: 'en', fallbackLocale: DEFAULT_LOCALE,
messages: { messages: {},
en,
zh
},
// 禁用 HTML 消息警告 - 引导步骤使用富文本内容driver.js 支持 HTML // 禁用 HTML 消息警告 - 引导步骤使用富文本内容driver.js 支持 HTML
// 这些内容是内部定义的,不存在 XSS 风险 // 这些内容是内部定义的,不存在 XSS 风险
warnHtmlMessage: false warnHtmlMessage: false
}) })
export function setLocale(locale: string) { const loadedLocales = new Set<LocaleCode>()
if (['en', 'zh'].includes(locale)) {
i18n.global.locale.value = locale as 'en' | 'zh' export async function loadLocaleMessages(locale: LocaleCode): Promise<void> {
localStorage.setItem(LOCALE_KEY, locale) if (loadedLocales.has(locale)) {
document.documentElement.setAttribute('lang', locale) return
} }
const loader = localeLoaders[locale]
const module = await loader()
i18n.global.setLocaleMessage(locale, module.default)
loadedLocales.add(locale)
} }
export function getLocale(): string { export async function initI18n(): Promise<void> {
return i18n.global.locale.value const current = getLocale()
await loadLocaleMessages(current)
document.documentElement.setAttribute('lang', current)
}
export async function setLocale(locale: string): Promise<void> {
if (!isLocaleCode(locale)) {
return
}
await loadLocaleMessages(locale)
i18n.global.locale.value = locale
localStorage.setItem(LOCALE_KEY, locale)
document.documentElement.setAttribute('lang', locale)
}
export function getLocale(): LocaleCode {
const current = i18n.global.locale.value
return isLocaleCode(current) ? current : DEFAULT_LOCALE
} }
export const availableLocales = [ export const availableLocales = [
{ code: 'en', name: 'English', flag: '🇺🇸' }, { code: 'en', name: 'English', flag: '🇺🇸' },
{ code: 'zh', name: '中文', flag: '🇨🇳' } { code: 'zh', name: '中文', flag: '🇨🇳' }
] ] as const
export default i18n export default i18n

View File

@@ -2,28 +2,33 @@ import { createApp } from 'vue'
import { createPinia } from 'pinia' import { createPinia } from 'pinia'
import App from './App.vue' import App from './App.vue'
import router from './router' import router from './router'
import i18n from './i18n' import i18n, { initI18n } from './i18n'
import { useAppStore } from '@/stores/app'
import './style.css' import './style.css'
const app = createApp(App) async function bootstrap() {
const pinia = createPinia() const app = createApp(App)
app.use(pinia) const pinia = createPinia()
app.use(pinia)
// Initialize settings from injected config BEFORE mounting (prevents flash) // Initialize settings from injected config BEFORE mounting (prevents flash)
// This must happen after pinia is installed but before router and i18n // This must happen after pinia is installed but before router and i18n
import { useAppStore } from '@/stores/app' const appStore = useAppStore()
const appStore = useAppStore() appStore.initFromInjectedConfig()
appStore.initFromInjectedConfig()
// Set document title immediately after config is loaded // Set document title immediately after config is loaded
if (appStore.siteName && appStore.siteName !== 'Sub2API') { if (appStore.siteName && appStore.siteName !== 'Sub2API') {
document.title = `${appStore.siteName} - AI API Gateway` document.title = `${appStore.siteName} - AI API Gateway`
}
await initI18n()
app.use(router)
app.use(i18n)
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
await router.isReady()
app.mount('#app')
} }
app.use(router) bootstrap()
app.use(i18n)
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
router.isReady().then(() => {
app.mount('#app')
})

View File

@@ -0,0 +1,25 @@
import { describe, expect, it } from 'vitest'
import { resolveDocumentTitle } from '@/router/title'
describe('resolveDocumentTitle', () => {
it('路由存在标题时,使用“路由标题 - 站点名”格式', () => {
expect(resolveDocumentTitle('Usage Records', 'My Site')).toBe('Usage Records - My Site')
})
it('路由无标题时,回退到站点名', () => {
expect(resolveDocumentTitle(undefined, 'My Site')).toBe('My Site')
})
it('站点名为空时,回退默认站点名', () => {
expect(resolveDocumentTitle('Dashboard', '')).toBe('Dashboard - Sub2API')
expect(resolveDocumentTitle(undefined, ' ')).toBe('Sub2API')
})
it('站点名变更时仅影响后续路由标题计算', () => {
const before = resolveDocumentTitle('Admin Dashboard', 'Alpha')
const after = resolveDocumentTitle('Admin Dashboard', 'Beta')
expect(before).toBe('Admin Dashboard - Alpha')
expect(after).toBe('Admin Dashboard - Beta')
})
})

View File

@@ -8,6 +8,7 @@ import { useAuthStore } from '@/stores/auth'
import { useAppStore } from '@/stores/app' import { useAppStore } from '@/stores/app'
import { useNavigationLoadingState } from '@/composables/useNavigationLoading' import { useNavigationLoadingState } from '@/composables/useNavigationLoading'
import { useRoutePrefetch } from '@/composables/useRoutePrefetch' import { useRoutePrefetch } from '@/composables/useRoutePrefetch'
import { resolveDocumentTitle } from './title'
/** /**
* Route definitions with lazy loading * Route definitions with lazy loading
@@ -389,12 +390,7 @@ router.beforeEach((to, _from, next) => {
// Set page title // Set page title
const appStore = useAppStore() const appStore = useAppStore()
const siteName = appStore.siteName || 'Sub2API' document.title = resolveDocumentTitle(to.meta.title, appStore.siteName)
if (to.meta.title) {
document.title = `${to.meta.title} - ${siteName}`
} else {
document.title = siteName
}
// Check if route requires authentication // Check if route requires authentication
const requiresAuth = to.meta.requiresAuth !== false // Default to true const requiresAuth = to.meta.requiresAuth !== false // Default to true

View File

@@ -0,0 +1,12 @@
/**
* 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。
*/
export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string {
const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API'
if (typeof routeTitle === 'string' && routeTitle.trim()) {
return `${routeTitle.trim()} - ${normalizedSiteName}`
}
return normalizedSiteName
}

View File

@@ -0,0 +1,37 @@
import { describe, expect, it } from 'vitest'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
describe('createStableObjectKeyResolver', () => {
it('对同一对象返回稳定 key', () => {
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
const obj = { value: 'a' }
const key1 = resolve(obj)
const key2 = resolve(obj)
expect(key1).toBe(key2)
expect(key1.startsWith('rule-')).toBe(true)
})
it('不同对象返回不同 key', () => {
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
const key1 = resolve({ value: 'a' })
const key2 = resolve({ value: 'a' })
expect(key1).not.toBe(key2)
})
it('不同 resolver 互不影响', () => {
const resolveA = createStableObjectKeyResolver<{ id: number }>('a')
const resolveB = createStableObjectKeyResolver<{ id: number }>('b')
const obj = { id: 1 }
const keyA = resolveA(obj)
const keyB = resolveB(obj)
expect(keyA).not.toBe(keyB)
expect(keyA.startsWith('a-')).toBe(true)
expect(keyB.startsWith('b-')).toBe(true)
})
})

View File

@@ -0,0 +1,19 @@
let globalStableObjectKeySeed = 0
/**
* 为对象实例生成稳定 key基于 WeakMap不污染业务对象
*/
export function createStableObjectKeyResolver<T extends object>(prefix = 'item') {
const keyMap = new WeakMap<T, string>()
return (item: T): string => {
const cached = keyMap.get(item)
if (cached) {
return cached
}
const key = `${prefix}-${++globalStableObjectKeySeed}`
keyMap.set(item, key)
return key
}
}

View File

@@ -239,8 +239,8 @@
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template> <template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template>
</TablePageLayout> </TablePageLayout>
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" /> <CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="load" /> <EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="handleAccountUpdated" />
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="load" /> <ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="handleAccountUpdated" />
<AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" /> <AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" />
<AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" /> <AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" />
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" /> <AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
@@ -694,6 +694,53 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
} }
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() } const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
const handleDataImported = () => { showImportData.value = false; reload() } const handleDataImported = () => { showImportData.value = false; reload() }
const accountMatchesCurrentFilters = (account: Account) => {
if (params.platform && account.platform !== params.platform) return false
if (params.type && account.type !== params.type) return false
if (params.status) {
if (params.status === 'rate_limited') {
if (!account.rate_limit_reset_at) return false
const resetAt = new Date(account.rate_limit_reset_at).getTime()
if (!Number.isFinite(resetAt) || resetAt <= Date.now()) return false
} else if (account.status !== params.status) {
return false
}
}
const search = String(params.search || '').trim().toLowerCase()
if (search && !account.name.toLowerCase().includes(search)) return false
return true
}
const mergeRuntimeFields = (oldAccount: Account, updatedAccount: Account): Account => ({
...updatedAccount,
current_concurrency: updatedAccount.current_concurrency ?? oldAccount.current_concurrency,
current_window_cost: updatedAccount.current_window_cost ?? oldAccount.current_window_cost,
active_sessions: updatedAccount.active_sessions ?? oldAccount.active_sessions
})
const patchAccountInList = (updatedAccount: Account) => {
const index = accounts.value.findIndex(account => account.id === updatedAccount.id)
if (index === -1) return
const mergedAccount = mergeRuntimeFields(accounts.value[index], updatedAccount)
if (!accountMatchesCurrentFilters(mergedAccount)) {
accounts.value = accounts.value.filter(account => account.id !== mergedAccount.id)
selIds.value = selIds.value.filter(id => id !== mergedAccount.id)
if (menu.acc?.id === mergedAccount.id) {
menu.show = false
menu.acc = null
}
return
}
const nextAccounts = [...accounts.value]
nextAccounts[index] = mergedAccount
accounts.value = nextAccounts
if (edAcc.value?.id === mergedAccount.id) edAcc.value = mergedAccount
if (reAuthAcc.value?.id === mergedAccount.id) reAuthAcc.value = mergedAccount
if (tempUnschedAcc.value?.id === mergedAccount.id) tempUnschedAcc.value = mergedAccount
if (deletingAcc.value?.id === mergedAccount.id) deletingAcc.value = mergedAccount
if (menu.acc?.id === mergedAccount.id) menu.acc = mergedAccount
}
const handleAccountUpdated = (updatedAccount: Account) => {
patchAccountInList(updatedAccount)
}
const formatExportTimestamp = () => { const formatExportTimestamp = () => {
const now = new Date() const now = new Date()
const pad2 = (value: number) => String(value).padStart(2, '0') const pad2 = (value: number) => String(value).padStart(2, '0')
@@ -743,9 +790,32 @@ const closeReAuthModal = () => { showReAuth.value = false; reAuthAcc.value = nul
const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true } const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true }
const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true } const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true }
const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true } const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true }
const handleRefresh = async (a: Account) => { try { await adminAPI.accounts.refreshCredentials(a.id); load() } catch (error) { console.error('Failed to refresh credentials:', error) } } const handleRefresh = async (a: Account) => {
const handleResetStatus = async (a: Account) => { try { await adminAPI.accounts.clearError(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to reset status:', error) } } try {
const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } } const updated = await adminAPI.accounts.refreshCredentials(a.id)
patchAccountInList(updated)
} catch (error) {
console.error('Failed to refresh credentials:', error)
}
}
const handleResetStatus = async (a: Account) => {
try {
const updated = await adminAPI.accounts.clearError(a.id)
patchAccountInList(updated)
appStore.showSuccess(t('common.success'))
} catch (error) {
console.error('Failed to reset status:', error)
}
}
const handleClearRateLimit = async (a: Account) => {
try {
const updated = await adminAPI.accounts.clearRateLimit(a.id)
patchAccountInList(updated)
appStore.showSuccess(t('common.success'))
} catch (error) {
console.error('Failed to clear rate limit:', error)
}
}
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true } const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } } const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
const handleToggleSchedulable = async (a: Account) => { const handleToggleSchedulable = async (a: Account) => {
@@ -762,7 +832,17 @@ const handleToggleSchedulable = async (a: Account) => {
} }
} }
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true } const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } } const handleTempUnschedReset = async () => {
if(!tempUnschedAcc.value) return
try {
const updated = await adminAPI.accounts.clearError(tempUnschedAcc.value.id)
showTempUnsched.value = false
tempUnschedAcc.value = null
patchAccountInList(updated)
} catch (error) {
console.error('Failed to reset temp unscheduled:', error)
}
}
const formatExpiresAt = (value: number | null) => { const formatExpiresAt = (value: number | null) => {
if (!value) return '-' if (!value) return '-'
return formatDateTime( return formatDateTime(

View File

@@ -759,8 +759,8 @@
<!-- 路由规则列表仅在启用时显示 --> <!-- 路由规则列表仅在启用时显示 -->
<div v-if="createForm.model_routing_enabled" class="space-y-3"> <div v-if="createForm.model_routing_enabled" class="space-y-3">
<div <div
v-for="(rule, index) in createModelRoutingRules" v-for="rule in createModelRoutingRules"
:key="index" :key="getCreateRuleRenderKey(rule)"
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600" class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
> >
<div class="flex items-start gap-3"> <div class="flex items-start gap-3">
@@ -786,7 +786,7 @@
{{ account.name }} {{ account.name }}
<button <button
type="button" type="button"
@click="removeSelectedAccount(index, account.id, false)" @click="removeSelectedAccount(rule, account.id)"
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200" class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
> >
<Icon name="x" size="xs" /> <Icon name="x" size="xs" />
@@ -796,23 +796,23 @@
<!-- 账号搜索输入框 --> <!-- 账号搜索输入框 -->
<div class="relative account-search-container"> <div class="relative account-search-container">
<input <input
v-model="accountSearchKeyword[`create-${index}`]" v-model="accountSearchKeyword[getCreateRuleSearchKey(rule)]"
type="text" type="text"
class="input text-sm" class="input text-sm"
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')" :placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
@input="searchAccounts(`create-${index}`)" @input="searchAccountsByRule(rule)"
@focus="onAccountSearchFocus(index, false)" @focus="onAccountSearchFocus(rule)"
/> />
<!-- 搜索结果下拉框 --> <!-- 搜索结果下拉框 -->
<div <div
v-if="showAccountDropdown[`create-${index}`] && accountSearchResults[`create-${index}`]?.length > 0" v-if="showAccountDropdown[getCreateRuleSearchKey(rule)] && accountSearchResults[getCreateRuleSearchKey(rule)]?.length > 0"
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800" class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
> >
<button <button
v-for="account in accountSearchResults[`create-${index}`]" v-for="account in accountSearchResults[getCreateRuleSearchKey(rule)]"
:key="account.id" :key="account.id"
type="button" type="button"
@click="selectAccount(index, account, false)" @click="selectAccount(rule, account)"
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700" class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }" :class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
:disabled="rule.accounts.some(a => a.id === account.id)" :disabled="rule.accounts.some(a => a.id === account.id)"
@@ -827,7 +827,7 @@
</div> </div>
<button <button
type="button" type="button"
@click="removeCreateRoutingRule(index)" @click="removeCreateRoutingRule(rule)"
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors" class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
:title="t('admin.groups.modelRouting.removeRule')" :title="t('admin.groups.modelRouting.removeRule')"
> >
@@ -1439,8 +1439,8 @@
<!-- 路由规则列表仅在启用时显示 --> <!-- 路由规则列表仅在启用时显示 -->
<div v-if="editForm.model_routing_enabled" class="space-y-3"> <div v-if="editForm.model_routing_enabled" class="space-y-3">
<div <div
v-for="(rule, index) in editModelRoutingRules" v-for="rule in editModelRoutingRules"
:key="index" :key="getEditRuleRenderKey(rule)"
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600" class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
> >
<div class="flex items-start gap-3"> <div class="flex items-start gap-3">
@@ -1466,7 +1466,7 @@
{{ account.name }} {{ account.name }}
<button <button
type="button" type="button"
@click="removeSelectedAccount(index, account.id, true)" @click="removeSelectedAccount(rule, account.id, true)"
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200" class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
> >
<Icon name="x" size="xs" /> <Icon name="x" size="xs" />
@@ -1476,23 +1476,23 @@
<!-- 账号搜索输入框 --> <!-- 账号搜索输入框 -->
<div class="relative account-search-container"> <div class="relative account-search-container">
<input <input
v-model="accountSearchKeyword[`edit-${index}`]" v-model="accountSearchKeyword[getEditRuleSearchKey(rule)]"
type="text" type="text"
class="input text-sm" class="input text-sm"
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')" :placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
@input="searchAccounts(`edit-${index}`)" @input="searchAccountsByRule(rule, true)"
@focus="onAccountSearchFocus(index, true)" @focus="onAccountSearchFocus(rule, true)"
/> />
<!-- 搜索结果下拉框 --> <!-- 搜索结果下拉框 -->
<div <div
v-if="showAccountDropdown[`edit-${index}`] && accountSearchResults[`edit-${index}`]?.length > 0" v-if="showAccountDropdown[getEditRuleSearchKey(rule)] && accountSearchResults[getEditRuleSearchKey(rule)]?.length > 0"
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800" class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
> >
<button <button
v-for="account in accountSearchResults[`edit-${index}`]" v-for="account in accountSearchResults[getEditRuleSearchKey(rule)]"
:key="account.id" :key="account.id"
type="button" type="button"
@click="selectAccount(index, account, true)" @click="selectAccount(rule, account, true)"
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700" class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }" :class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
:disabled="rule.accounts.some(a => a.id === account.id)" :disabled="rule.accounts.some(a => a.id === account.id)"
@@ -1507,7 +1507,7 @@
</div> </div>
<button <button
type="button" type="button"
@click="removeEditRoutingRule(index)" @click="removeEditRoutingRule(rule)"
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors" class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
:title="t('admin.groups.modelRouting.removeRule')" :title="t('admin.groups.modelRouting.removeRule')"
> >
@@ -1687,6 +1687,8 @@ import Select from '@/components/common/Select.vue'
import PlatformIcon from '@/components/common/PlatformIcon.vue' import PlatformIcon from '@/components/common/PlatformIcon.vue'
import Icon from '@/components/icons/Icon.vue' import Icon from '@/components/icons/Icon.vue'
import { VueDraggable } from 'vue-draggable-plus' import { VueDraggable } from 'vue-draggable-plus'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
const { t } = useI18n() const { t } = useI18n()
const appStore = useAppStore() const appStore = useAppStore()
@@ -1911,33 +1913,70 @@ const createModelRoutingRules = ref<ModelRoutingRule[]>([])
// 编辑表单的模型路由规则 // 编辑表单的模型路由规则
const editModelRoutingRules = ref<ModelRoutingRule[]>([]) const editModelRoutingRules = ref<ModelRoutingRule[]>([])
// 账号搜索相关状态 // 规则对象稳定 key避免使用 index 导致状态错位)
const accountSearchKeyword = ref<Record<string, string>>({}) // 每个规则的搜索关键词 (key: "create-0" 或 "edit-0") const resolveCreateRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('create-rule')
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({}) // 每个规则的搜索结果 const resolveEditRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('edit-rule')
const showAccountDropdown = ref<Record<string, boolean>>({}) // 每个规则的下拉框显示状态
let accountSearchTimeout: ReturnType<typeof setTimeout> | null = null
// 搜索账号(仅限 anthropic 平台) const getCreateRuleRenderKey = (rule: ModelRoutingRule) => resolveCreateRuleKey(rule)
const searchAccounts = async (key: string) => { const getEditRuleRenderKey = (rule: ModelRoutingRule) => resolveEditRuleKey(rule)
if (accountSearchTimeout) clearTimeout(accountSearchTimeout)
accountSearchTimeout = setTimeout(async () => { const getCreateRuleSearchKey = (rule: ModelRoutingRule) => `create-${resolveCreateRuleKey(rule)}`
const keyword = accountSearchKeyword.value[key] || '' const getEditRuleSearchKey = (rule: ModelRoutingRule) => `edit-${resolveEditRuleKey(rule)}`
try {
const res = await adminAPI.accounts.list(1, 20, { const getRuleSearchKey = (rule: ModelRoutingRule, isEdit: boolean = false) => {
return isEdit ? getEditRuleSearchKey(rule) : getCreateRuleSearchKey(rule)
}
// 账号搜索相关状态
const accountSearchKeyword = ref<Record<string, string>>({})
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({})
const showAccountDropdown = ref<Record<string, boolean>>({})
const clearAccountSearchStateByKey = (key: string) => {
delete accountSearchKeyword.value[key]
delete accountSearchResults.value[key]
delete showAccountDropdown.value[key]
}
const clearAllAccountSearchState = () => {
accountSearchKeyword.value = {}
accountSearchResults.value = {}
showAccountDropdown.value = {}
}
const accountSearchRunner = useKeyedDebouncedSearch<SimpleAccount[]>({
delay: 300,
search: async (keyword, { signal }) => {
const res = await adminAPI.accounts.list(
1,
20,
{
search: keyword, search: keyword,
platform: 'anthropic' platform: 'anthropic'
}) },
accountSearchResults.value[key] = res.items.map((a) => ({ id: a.id, name: a.name })) { signal }
} catch { )
accountSearchResults.value[key] = [] return res.items.map((account) => ({ id: account.id, name: account.name }))
} },
}, 300) onSuccess: (key, result) => {
accountSearchResults.value[key] = result
},
onError: (key) => {
accountSearchResults.value[key] = []
}
})
// 搜索账号(仅限 anthropic 平台)
const searchAccounts = (key: string) => {
accountSearchRunner.trigger(key, accountSearchKeyword.value[key] || '')
}
const searchAccountsByRule = (rule: ModelRoutingRule, isEdit: boolean = false) => {
searchAccounts(getRuleSearchKey(rule, isEdit))
} }
// 选择账号 // 选择账号
const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolean = false) => { const selectAccount = (rule: ModelRoutingRule, account: SimpleAccount, isEdit: boolean = false) => {
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
const rule = rules[ruleIndex]
if (!rule) return if (!rule) return
// 检查是否已选择 // 检查是否已选择
@@ -1946,15 +1985,13 @@ const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolea
} }
// 清空搜索 // 清空搜索
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}` const key = getRuleSearchKey(rule, isEdit)
accountSearchKeyword.value[key] = '' accountSearchKeyword.value[key] = ''
showAccountDropdown.value[key] = false showAccountDropdown.value[key] = false
} }
// 移除已选账号 // 移除已选账号
const removeSelectedAccount = (ruleIndex: number, accountId: number, isEdit: boolean = false) => { const removeSelectedAccount = (rule: ModelRoutingRule, accountId: number, _isEdit: boolean = false) => {
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
const rule = rules[ruleIndex]
if (!rule) return if (!rule) return
rule.accounts = rule.accounts.filter(a => a.id !== accountId) rule.accounts = rule.accounts.filter(a => a.id !== accountId)
@@ -1981,8 +2018,8 @@ const toggleEditScope = (scope: string) => {
} }
// 处理账号搜索输入框聚焦 // 处理账号搜索输入框聚焦
const onAccountSearchFocus = (ruleIndex: number, isEdit: boolean = false) => { const onAccountSearchFocus = (rule: ModelRoutingRule, isEdit: boolean = false) => {
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}` const key = getRuleSearchKey(rule, isEdit)
showAccountDropdown.value[key] = true showAccountDropdown.value[key] = true
// 如果没有搜索结果,触发一次搜索 // 如果没有搜索结果,触发一次搜索
if (!accountSearchResults.value[key]?.length) { if (!accountSearchResults.value[key]?.length) {
@@ -1996,13 +2033,14 @@ const addCreateRoutingRule = () => {
} }
// 删除创建表单的路由规则 // 删除创建表单的路由规则
const removeCreateRoutingRule = (index: number) => { const removeCreateRoutingRule = (rule: ModelRoutingRule) => {
const index = createModelRoutingRules.value.indexOf(rule)
if (index === -1) return
const key = getCreateRuleSearchKey(rule)
accountSearchRunner.clearKey(key)
clearAccountSearchStateByKey(key)
createModelRoutingRules.value.splice(index, 1) createModelRoutingRules.value.splice(index, 1)
// 清理相关的搜索状态
const key = `create-${index}`
delete accountSearchKeyword.value[key]
delete accountSearchResults.value[key]
delete showAccountDropdown.value[key]
} }
// 添加编辑表单的路由规则 // 添加编辑表单的路由规则
@@ -2011,13 +2049,14 @@ const addEditRoutingRule = () => {
} }
// 删除编辑表单的路由规则 // 删除编辑表单的路由规则
const removeEditRoutingRule = (index: number) => { const removeEditRoutingRule = (rule: ModelRoutingRule) => {
const index = editModelRoutingRules.value.indexOf(rule)
if (index === -1) return
const key = getEditRuleSearchKey(rule)
accountSearchRunner.clearKey(key)
clearAccountSearchStateByKey(key)
editModelRoutingRules.value.splice(index, 1) editModelRoutingRules.value.splice(index, 1)
// 清理相关的搜索状态
const key = `edit-${index}`
delete accountSearchKeyword.value[key]
delete accountSearchResults.value[key]
delete showAccountDropdown.value[key]
} }
// 将 UI 格式的路由规则转换为 API 格式 // 将 UI 格式的路由规则转换为 API 格式
@@ -2161,6 +2200,10 @@ const handlePageSizeChange = (pageSize: number) => {
const closeCreateModal = () => { const closeCreateModal = () => {
showCreateModal.value = false showCreateModal.value = false
createModelRoutingRules.value.forEach((rule) => {
accountSearchRunner.clearKey(getCreateRuleSearchKey(rule))
})
clearAllAccountSearchState()
createForm.name = '' createForm.name = ''
createForm.description = '' createForm.description = ''
createForm.platform = 'anthropic' createForm.platform = 'anthropic'
@@ -2247,6 +2290,10 @@ const handleEdit = async (group: AdminGroup) => {
} }
const closeEditModal = () => { const closeEditModal = () => {
editModelRoutingRules.value.forEach((rule) => {
accountSearchRunner.clearKey(getEditRuleSearchKey(rule))
})
clearAllAccountSearchState()
showEditModal.value = false showEditModal.value = false
editingGroup.value = null editingGroup.value = null
editModelRoutingRules.value = [] editModelRoutingRules.value = []
@@ -2382,5 +2429,7 @@ onMounted(() => {
onUnmounted(() => { onUnmounted(() => {
document.removeEventListener('click', handleClickOutside) document.removeEventListener('click', handleClickOutside)
accountSearchRunner.clearAll()
clearAllAccountSearchState()
}) })
</script> </script>

View File

@@ -94,57 +94,44 @@ const exportToExcel = async () => {
if (exporting.value) return; exporting.value = true; exportProgress.show = true if (exporting.value) return; exporting.value = true; exportProgress.show = true
const c = new AbortController(); exportAbortController = c const c = new AbortController(); exportAbortController = c
try { try {
const all: AdminUsageLog[] = []; let p = 1; let total = pagination.total let p = 1; let total = pagination.total; let exportedCount = 0
const XLSX = await import('xlsx')
const headers = [
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
t('admin.usage.account'), t('usage.model'), t('usage.reasoningEffort'), t('admin.usage.group'),
t('usage.type'),
t('admin.usage.inputTokens'), t('admin.usage.outputTokens'),
t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'),
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
t('usage.rate'), t('usage.accountMultiplier'), t('usage.original'), t('usage.userBilled'), t('usage.accountBilled'),
t('usage.firstToken'), t('usage.duration'),
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
]
const ws = XLSX.utils.aoa_to_sheet([headers])
while (true) { while (true) {
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal }) const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total } if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
if (res.items?.length) all.push(...res.items) const rows = (res.items || []).map((log: AdminUsageLog) => [
exportProgress.current = all.length; exportProgress.progress = total > 0 ? Math.min(100, Math.round(all.length/total*100)) : 0 log.created_at, log.user?.email || '', log.api_key?.name || '', log.account?.name || '', log.model,
if (all.length >= total || res.items.length < 100) break; p++ formatReasoningEffort(log.reasoning_effort), log.group?.name || '', log.stream ? t('usage.stream') : t('usage.sync'),
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2),
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
log.request_id || '', log.user_agent || '', log.ip_address || ''
])
if (rows.length) {
XLSX.utils.sheet_add_aoa(ws, rows, { origin: -1 })
}
exportedCount += rows.length
exportProgress.current = exportedCount
exportProgress.progress = total > 0 ? Math.min(100, Math.round(exportedCount / total * 100)) : 0
if (exportedCount >= total || res.items.length < 100) break; p++
} }
if(!c.signal.aborted) { if(!c.signal.aborted) {
const XLSX = await import('xlsx')
const headers = [
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
t('admin.usage.account'), t('usage.model'), t('usage.reasoningEffort'), t('admin.usage.group'),
t('usage.type'),
t('admin.usage.inputTokens'), t('admin.usage.outputTokens'),
t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'),
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
t('usage.rate'), t('usage.accountMultiplier'), t('usage.original'), t('usage.userBilled'), t('usage.accountBilled'),
t('usage.firstToken'), t('usage.duration'),
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
]
const rows = all.map(log => [
log.created_at,
log.user?.email || '',
log.api_key?.name || '',
log.account?.name || '',
log.model,
formatReasoningEffort(log.reasoning_effort),
log.group?.name || '',
log.stream ? t('usage.stream') : t('usage.sync'),
log.input_tokens,
log.output_tokens,
log.cache_read_tokens,
log.cache_creation_tokens,
log.input_cost?.toFixed(6) || '0.000000',
log.output_cost?.toFixed(6) || '0.000000',
log.cache_read_cost?.toFixed(6) || '0.000000',
log.cache_creation_cost?.toFixed(6) || '0.000000',
log.rate_multiplier?.toFixed(2) || '1.00',
(log.account_rate_multiplier ?? 1).toFixed(2),
log.total_cost?.toFixed(6) || '0.000000',
log.actual_cost?.toFixed(6) || '0.000000',
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6),
log.first_token_ms ?? '',
log.duration_ms,
log.request_id || '',
log.user_agent || '',
log.ip_address || ''
])
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
const wb = XLSX.utils.book_new() const wb = XLSX.utils.book_new()
XLSX.utils.book_append_sheet(wb, ws, 'Usage') XLSX.utils.book_append_sheet(wb, ws, 'Usage')
saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`) saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`)