Merge branch 'test' into release
This commit is contained in:
@@ -406,6 +406,14 @@ gateway:
|
||||
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
||||
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
||||
|
||||
**网关防御纵深建议(重点)**
|
||||
|
||||
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
|
||||
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
|
||||
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
|
||||
- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。
|
||||
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
|
||||
|
||||
**⚠️ 安全警告:HTTP URL 配置**
|
||||
|
||||
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
||||
|
||||
@@ -308,6 +308,12 @@ type GatewayConfig struct {
|
||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||
// 请求体最大字节数,用于网关请求体大小限制
|
||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||
// 代理探测响应体读取上限(字节)
|
||||
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||
@@ -1059,6 +1065,9 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.MaxBodySize <= 0 {
|
||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||
}
|
||||
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||
}
|
||||
if c.Gateway.SoraMaxBodySize < 0 {
|
||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||
}
|
||||
|
||||
@@ -1106,7 +1106,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
|
||||
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("gateway.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse 返回Claude API格式的错误响应
|
||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", parsed["type"])
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||
return false
|
||||
}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||
return true
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||
// 适用于 ACL / 风控等安全敏感场景。
|
||||
func GetTrustedClientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
require.NoError(t, r.SetTrustedProxies(nil))
|
||||
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
c.String(200, GetTrustedClientIP(c))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -1194,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||
if phaseFilter != "upstream" {
|
||||
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
||||
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
|
||||
}
|
||||
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
@@ -1208,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||
args = append(args, p)
|
||||
clauses = append(clauses, "platform = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
args = append(args, *filter.GroupID)
|
||||
clauses = append(clauses, "group_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
args = append(args, *filter.AccountID)
|
||||
clauses = append(clauses, "account_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
|
||||
}
|
||||
if phase := phaseFilter; phase != "" {
|
||||
args = append(args, phase)
|
||||
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
|
||||
}
|
||||
if filter != nil {
|
||||
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
||||
args = append(args, owner)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
||||
args = append(args, source)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
}
|
||||
if resolvedFilter != nil {
|
||||
args = append(args, *resolvedFilter)
|
||||
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
// View filter: errors vs excluded vs all.
|
||||
@@ -1246,46 +1246,46 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
switch view {
|
||||
case "", "errors":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
case "excluded":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
|
||||
case "all":
|
||||
// no-op
|
||||
default:
|
||||
// treat unknown as default 'errors'
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
}
|
||||
if len(filter.StatusCodes) > 0 {
|
||||
args = append(args, pq.Array(filter.StatusCodes))
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
} else if filter.StatusCodesOther {
|
||||
// "Other" means: status codes not in the common list.
|
||||
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
||||
args = append(args, pq.Array(known))
|
||||
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
}
|
||||
// Exact correlation keys (preferred for request↔upstream linkage).
|
||||
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
||||
args = append(args, rid)
|
||||
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
||||
args = append(args, crid)
|
||||
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
||||
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
||||
like := "%" + userQuery + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "u.email ILIKE $"+n)
|
||||
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
|
||||
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
Query: "ACCESS_DENIED",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "e.request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.client_request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified client_request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.error_message ILIKE $") {
|
||||
t.Fatalf("where should include qualified error_message condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
UserQuery: "admin@",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
|
||||
t.Fatalf("where should include EXISTS user email condition: %s", where)
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||
}
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
maxResponseBytes int64
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxResponseBytes := s.maxResponseBytes
|
||||
if maxResponseBytes <= 0 {
|
||||
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBytes {
|
||||
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
|
||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
||||
if err := r.SetTrustedProxies(nil); err != nil {
|
||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||
}
|
||||
if cfg.Server.Mode == "release" {
|
||||
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
|
||||
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Register)
|
||||
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login)
|
||||
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build integration
|
||||
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startAuthRouteRedis(t, ctx)
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
const path = "/api/v1/auth/register"
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "198.51.100.10:23456"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if i <= 5 {
|
||||
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||
continue
|
||||
}
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureAuthRouteDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if authRouteDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||
}
|
||||
|
||||
func authRouteDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func authRouteUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
RegisterAuthRoutes(
|
||||
v1,
|
||||
&handler.Handlers{
|
||||
Auth: &handler.AuthHandler{},
|
||||
Setting: &handler.SettingHandler{},
|
||||
},
|
||||
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
paths := []string{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/login/2fa",
|
||||
"/api/v1/auth/send-verify-code",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "203.0.113.10:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||
}
|
||||
}
|
||||
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2358,7 +2358,7 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
@@ -2366,21 +2366,28 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if isOAuth {
|
||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||
if uwErr == nil {
|
||||
respBody = unwrappedBody
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
} else {
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
@@ -2398,7 +2405,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
@@ -2406,6 +2413,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -3,10 +3,15 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
GeminiDebugResponseHeaders: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-RateLimit-Limit": []string{"60"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||
}
|
||||
|
||||
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
@@ -313,7 +313,6 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
|
||||
}
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if result.Matched {
|
||||
log.Warn("OpenAI codex_cli_only 允许官方客户端请求")
|
||||
return
|
||||
}
|
||||
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
||||
@@ -1277,6 +1276,29 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
if account != nil && account.Type == AccountTypeOAuth {
|
||||
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
|
||||
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
|
||||
setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: http.StatusForbidden,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: rejectMsg,
|
||||
Detail: rejectReason,
|
||||
})
|
||||
logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "forbidden_error",
|
||||
"message": rejectMsg,
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
|
||||
}
|
||||
|
||||
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1396,6 +1418,37 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func logOpenAIPassthroughInstructionsRejected(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
reqModel string,
|
||||
rejectReason string,
|
||||
body []byte,
|
||||
) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
accountID := int64(0)
|
||||
accountName := ""
|
||||
accountType := ""
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
accountName = strings.TrimSpace(account.Name)
|
||||
accountType = strings.TrimSpace(string(account.Type))
|
||||
}
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.String("account_name", accountName),
|
||||
zap.String("account_type", accountType),
|
||||
zap.String("request_model", strings.TrimSpace(reqModel)),
|
||||
zap.String("reject_reason", strings.TrimSpace(rejectReason)),
|
||||
}
|
||||
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
|
||||
logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
@@ -1688,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2318,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2877,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
||||
return normalized, changed, nil
|
||||
}
|
||||
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
|
||||
@@ -103,7 +103,7 @@ func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
||||
func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) {
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||
}, nil)
|
||||
|
||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
||||
require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||
}, body)
|
||||
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
|
||||
|
||||
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
c.Request.Header.Set("Proxy-Authorization", "Basic abc")
|
||||
c.Request.Header.Set("X-Test", "keep")
|
||||
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstreamSSE := strings.Join([]string{
|
||||
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
|
||||
@@ -211,6 +211,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
|
||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String()))
|
||||
// 其余关键字段保持原值。
|
||||
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
|
||||
@@ -235,6 +236,59 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
||||
require.NotContains(t, body, "\"name\":\"edit\"")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
|
||||
// Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。
|
||||
originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
upstream := &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "requires a non-empty instructions field")
|
||||
require.Nil(t, upstream.lastReq)
|
||||
|
||||
require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||
require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||
|
||||
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||
|
||||
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||
}
|
||||
return defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, errors.New("response body is nil")
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||
t.Run("use default when config missing", func(t *testing.T) {
|
||||
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||
})
|
||||
|
||||
t.Run("use configured value", func(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||
t.Run("within limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ok"), body)
|
||||
})
|
||||
|
||||
t.Run("exceeds limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||
require.Nil(t, body)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||
})
|
||||
}
|
||||
@@ -146,6 +146,15 @@ gateway:
|
||||
# Max request body size in bytes (default: 100MB)
|
||||
# 请求体最大字节数(默认 100MB)
|
||||
max_body_size: 104857600
|
||||
# Max bytes to read for non-stream upstream responses (default: 8MB)
|
||||
# 非流式上游响应体读取上限(默认 8MB)
|
||||
upstream_response_read_max_bytes: 8388608
|
||||
# Max bytes to read for proxy probe responses (default: 1MB)
|
||||
# 代理探测响应体读取上限(默认 1MB)
|
||||
proxy_probe_response_read_max_bytes: 1048576
|
||||
# Enable Gemini upstream response header debug logs (default: false)
|
||||
# 是否开启 Gemini 上游响应头调试日志(默认 false)
|
||||
gemini_debug_response_headers: false
|
||||
# Sora max request body size in bytes (0=use max_body_size)
|
||||
# Sora 请求体最大字节数(0=使用 max_body_size)
|
||||
sora_max_body_size: 268435456
|
||||
|
||||
@@ -39,16 +39,6 @@ watch(
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
watch(
|
||||
() => appStore.siteName,
|
||||
(newName) => {
|
||||
if (newName) {
|
||||
document.title = `${newName} - AI API Gateway`
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
// Watch for authentication state and manage subscription data
|
||||
watch(
|
||||
() => authStore.isAuthenticated,
|
||||
|
||||
@@ -58,12 +58,16 @@ describe('ImportDataModal', () => {
|
||||
|
||||
const input = wrapper.find('input[type="file"]')
|
||||
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
||||
Object.defineProperty(file, 'text', {
|
||||
value: () => Promise.resolve('invalid json')
|
||||
})
|
||||
Object.defineProperty(input.element, 'files', {
|
||||
value: [file]
|
||||
})
|
||||
|
||||
await input.trigger('change')
|
||||
await wrapper.find('form').trigger('submit')
|
||||
await Promise.resolve()
|
||||
|
||||
expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed')
|
||||
})
|
||||
|
||||
@@ -58,12 +58,16 @@ describe('Proxy ImportDataModal', () => {
|
||||
|
||||
const input = wrapper.find('input[type="file"]')
|
||||
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
||||
Object.defineProperty(file, 'text', {
|
||||
value: () => Promise.resolve('invalid json')
|
||||
})
|
||||
Object.defineProperty(input.element, 'files', {
|
||||
value: [file]
|
||||
})
|
||||
|
||||
await input.trigger('change')
|
||||
await wrapper.find('form').trigger('submit')
|
||||
await Promise.resolve()
|
||||
|
||||
expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed')
|
||||
})
|
||||
|
||||
@@ -164,10 +164,10 @@ export async function getUsage(id: number): Promise<AccountUsageInfo> {
|
||||
/**
|
||||
* Clear account rate limit status
|
||||
* @param id - Account ID
|
||||
* @returns Success confirmation
|
||||
* @returns Updated account
|
||||
*/
|
||||
export async function clearRateLimit(id: number): Promise<{ message: string }> {
|
||||
const { data } = await apiClient.post<{ message: string }>(
|
||||
export async function clearRateLimit(id: number): Promise<Account> {
|
||||
const { data } = await apiClient.post<Account>(
|
||||
`/admin/accounts/${id}/clear-rate-limit`
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -209,7 +209,7 @@
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="index"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
@@ -654,6 +654,7 @@ import Select from '@/components/common/Select.vue'
|
||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
|
||||
interface Props {
|
||||
show: boolean
|
||||
@@ -695,6 +696,7 @@ const baseUrl = ref('')
|
||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||
const allowedModels = ref<string[]>([])
|
||||
const modelMappings = ref<ModelMapping[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
|
||||
const selectedErrorCodes = ref<number[]>([])
|
||||
const customErrorCodeInput = ref<number | null>(null)
|
||||
const interceptWarmupRequests = ref(false)
|
||||
|
||||
@@ -714,7 +714,7 @@
|
||||
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in antigravityModelMappings"
|
||||
:key="index"
|
||||
:key="getAntigravityModelMappingKey(mapping)"
|
||||
class="space-y-1"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
@@ -966,7 +966,7 @@
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="index"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
@@ -1225,7 +1225,7 @@
|
||||
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
||||
<div
|
||||
v-for="(rule, index) in tempUnschedRules"
|
||||
:key="index"
|
||||
:key="getTempUnschedRuleKey(rule)"
|
||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||
>
|
||||
<div class="mb-2 flex items-center justify-between">
|
||||
@@ -2097,6 +2097,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||
|
||||
// Type for exposed OAuthAuthorizationFlow component
|
||||
@@ -2227,6 +2228,9 @@ const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
|
||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
|
||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||
const geminiAIStudioOAuthEnabled = ref(false)
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@
|
||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in modelMappings"
|
||||
:key="index"
|
||||
:key="getModelMappingKey(mapping)"
|
||||
class="flex items-center gap-2"
|
||||
>
|
||||
<input
|
||||
@@ -417,7 +417,7 @@
|
||||
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
||||
<div
|
||||
v-for="(mapping, index) in antigravityModelMappings"
|
||||
:key="index"
|
||||
:key="getAntigravityModelMappingKey(mapping)"
|
||||
class="space-y-1"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
@@ -542,7 +542,7 @@
|
||||
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
||||
<div
|
||||
v-for="(rule, index) in tempUnschedRules"
|
||||
:key="index"
|
||||
:key="getTempUnschedRuleKey(rule)"
|
||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||
>
|
||||
<div class="mb-2 flex items-center justify-between">
|
||||
@@ -1093,6 +1093,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import {
|
||||
getPresetMappingsByPlatform,
|
||||
commonErrorCodes,
|
||||
@@ -1110,7 +1111,7 @@ interface Props {
|
||||
const props = defineProps<Props>()
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
updated: []
|
||||
updated: [account: Account]
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
@@ -1158,6 +1159,9 @@ const antigravityWhitelistModels = ref<string[]>([])
|
||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||
const tempUnschedEnabled = ref(false)
|
||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-model-mapping')
|
||||
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
||||
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
||||
|
||||
// Mixed channel warning dialog state
|
||||
const showMixedChannelWarning = ref(false)
|
||||
@@ -1845,9 +1849,9 @@ const handleSubmit = async () => {
|
||||
updatePayload.extra = newExtra
|
||||
}
|
||||
|
||||
await adminAPI.accounts.update(props.account.id, updatePayload)
|
||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
|
||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||
emit('updated')
|
||||
emit('updated', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
||||
@@ -1875,9 +1879,9 @@ const handleMixedChannelConfirm = async () => {
|
||||
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
||||
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||
emit('updated')
|
||||
emit('updated', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||
|
||||
@@ -143,6 +143,24 @@ const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const readFileAsText = async (sourceFile: File): Promise<string> => {
|
||||
if (typeof sourceFile.text === 'function') {
|
||||
return sourceFile.text()
|
||||
}
|
||||
|
||||
if (typeof sourceFile.arrayBuffer === 'function') {
|
||||
const buffer = await sourceFile.arrayBuffer()
|
||||
return new TextDecoder().decode(buffer)
|
||||
}
|
||||
|
||||
return await new Promise<string>((resolve, reject) => {
|
||||
const reader = new FileReader()
|
||||
reader.onload = () => resolve(String(reader.result ?? ''))
|
||||
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
|
||||
reader.readAsText(sourceFile)
|
||||
})
|
||||
}
|
||||
|
||||
const handleImport = async () => {
|
||||
if (!file.value) {
|
||||
appStore.showError(t('admin.accounts.dataImportSelectFile'))
|
||||
@@ -151,7 +169,7 @@ const handleImport = async () => {
|
||||
|
||||
importing.value = true
|
||||
try {
|
||||
const text = await file.value.text()
|
||||
const text = await readFileAsText(file.value)
|
||||
const dataPayload = JSON.parse(text)
|
||||
|
||||
const res = await adminAPI.accounts.importData({
|
||||
|
||||
@@ -216,7 +216,7 @@ interface Props {
|
||||
const props = defineProps<Props>()
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
reauthorized: []
|
||||
reauthorized: [account: Account]
|
||||
}>()
|
||||
|
||||
const appStore = useAppStore()
|
||||
@@ -370,10 +370,10 @@ const handleExchangeCode = async () => {
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
await adminAPI.accounts.clearError(props.account.id)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized')
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
@@ -404,9 +404,9 @@ const handleExchangeCode = async () => {
|
||||
type: 'oauth',
|
||||
credentials
|
||||
})
|
||||
await adminAPI.accounts.clearError(props.account.id)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized')
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
@@ -436,9 +436,9 @@ const handleExchangeCode = async () => {
|
||||
type: 'oauth',
|
||||
credentials
|
||||
})
|
||||
await adminAPI.accounts.clearError(props.account.id)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized')
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
@@ -475,10 +475,10 @@ const handleExchangeCode = async () => {
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
await adminAPI.accounts.clearError(props.account.id)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized')
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||
@@ -518,10 +518,10 @@ const handleCookieAuth = async (sessionKey: string) => {
|
||||
})
|
||||
|
||||
// Clear error status after successful re-authorization
|
||||
await adminAPI.accounts.clearError(props.account.id)
|
||||
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||
|
||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||
emit('reauthorized')
|
||||
emit('reauthorized', updatedAccount)
|
||||
handleClose()
|
||||
} catch (error: any) {
|
||||
claudeOAuth.error.value =
|
||||
|
||||
@@ -143,6 +143,24 @@ const handleClose = () => {
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const readFileAsText = async (sourceFile: File): Promise<string> => {
|
||||
if (typeof sourceFile.text === 'function') {
|
||||
return sourceFile.text()
|
||||
}
|
||||
|
||||
if (typeof sourceFile.arrayBuffer === 'function') {
|
||||
const buffer = await sourceFile.arrayBuffer()
|
||||
return new TextDecoder().decode(buffer)
|
||||
}
|
||||
|
||||
return await new Promise<string>((resolve, reject) => {
|
||||
const reader = new FileReader()
|
||||
reader.onload = () => resolve(String(reader.result ?? ''))
|
||||
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
|
||||
reader.readAsText(sourceFile)
|
||||
})
|
||||
}
|
||||
|
||||
const handleImport = async () => {
|
||||
if (!file.value) {
|
||||
appStore.showError(t('admin.proxies.dataImportSelectFile'))
|
||||
@@ -151,7 +169,7 @@ const handleImport = async () => {
|
||||
|
||||
importing.value = true
|
||||
try {
|
||||
const text = await file.value.text()
|
||||
const text = await readFileAsText(file.value)
|
||||
const dataPayload = JSON.parse(text)
|
||||
|
||||
const res = await adminAPI.proxies.importData({ data: dataPayload })
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
<template v-if="loading">
|
||||
<div v-for="i in 5" :key="i" class="rounded-lg border border-gray-200 bg-white p-4 dark:border-dark-700 dark:bg-dark-900">
|
||||
<div class="space-y-3">
|
||||
<div v-for="column in columns.filter(c => c.key !== 'actions')" :key="column.key" class="flex justify-between">
|
||||
<div v-for="column in dataColumns" :key="column.key" class="flex justify-between">
|
||||
<div class="h-4 w-20 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
||||
<div class="h-4 w-32 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
||||
</div>
|
||||
@@ -39,7 +39,7 @@
|
||||
>
|
||||
<div class="space-y-3">
|
||||
<div
|
||||
v-for="column in columns.filter(c => c.key !== 'actions')"
|
||||
v-for="column in dataColumns"
|
||||
:key="column.key"
|
||||
class="flex items-start justify-between gap-4"
|
||||
>
|
||||
@@ -439,10 +439,15 @@ const resolveRowKey = (row: any, index: number) => {
|
||||
return key ?? index
|
||||
}
|
||||
|
||||
const dataColumns = computed(() => props.columns.filter((column) => column.key !== 'actions'))
|
||||
const columnsSignature = computed(() =>
|
||||
props.columns.map((column) => `${column.key}:${column.sortable ? '1' : '0'}`).join('|')
|
||||
)
|
||||
|
||||
// 数据/列变化时重新检查滚动状态
|
||||
// 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
|
||||
watch(
|
||||
[() => props.data.length, () => props.columns],
|
||||
[() => props.data.length, columnsSignature],
|
||||
async () => {
|
||||
await nextTick()
|
||||
checkScrollable()
|
||||
@@ -555,7 +560,7 @@ onMounted(() => {
|
||||
})
|
||||
|
||||
watch(
|
||||
() => props.columns,
|
||||
columnsSignature,
|
||||
() => {
|
||||
// If current sort key is no longer sortable/visible, fall back to default/persisted.
|
||||
const normalized = normalizeSortKey(sortKey.value)
|
||||
@@ -575,7 +580,7 @@ watch(
|
||||
}
|
||||
}
|
||||
},
|
||||
{ deep: true }
|
||||
{ flush: 'post' }
|
||||
)
|
||||
|
||||
watch(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
<div class="relative" ref="dropdownRef">
|
||||
<button
|
||||
@click="toggleDropdown"
|
||||
:disabled="switching"
|
||||
class="flex items-center gap-1.5 rounded-lg px-2 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700"
|
||||
:title="currentLocale?.name"
|
||||
>
|
||||
@@ -23,6 +24,7 @@
|
||||
<button
|
||||
v-for="locale in availableLocales"
|
||||
:key="locale.code"
|
||||
:disabled="switching"
|
||||
@click="selectLocale(locale.code)"
|
||||
class="flex w-full items-center gap-2 px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-100 dark:text-gray-200 dark:hover:bg-dark-700"
|
||||
:class="{
|
||||
@@ -49,6 +51,7 @@ const { locale } = useI18n()
|
||||
|
||||
const isOpen = ref(false)
|
||||
const dropdownRef = ref<HTMLElement | null>(null)
|
||||
const switching = ref(false)
|
||||
|
||||
const currentLocaleCode = computed(() => locale.value)
|
||||
const currentLocale = computed(() => availableLocales.find((l) => l.code === locale.value))
|
||||
@@ -57,9 +60,18 @@ function toggleDropdown() {
|
||||
isOpen.value = !isOpen.value
|
||||
}
|
||||
|
||||
function selectLocale(code: string) {
|
||||
setLocale(code)
|
||||
async function selectLocale(code: string) {
|
||||
if (switching.value || code === currentLocaleCode.value) {
|
||||
isOpen.value = false
|
||||
return
|
||||
}
|
||||
switching.value = true
|
||||
try {
|
||||
await setLocale(code)
|
||||
isOpen.value = false
|
||||
} finally {
|
||||
switching.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handleClickOutside(event: MouseEvent) {
|
||||
|
||||
@@ -84,8 +84,8 @@
|
||||
|
||||
<!-- Page numbers -->
|
||||
<button
|
||||
v-for="pageNum in visiblePages"
|
||||
:key="pageNum"
|
||||
v-for="(pageNum, index) in visiblePages"
|
||||
:key="`${pageNum}-${index}`"
|
||||
@click="typeof pageNum === 'number' && goToPage(pageNum)"
|
||||
:disabled="typeof pageNum !== 'number'"
|
||||
:class="[
|
||||
|
||||
@@ -66,8 +66,8 @@
|
||||
<!-- Progress bar -->
|
||||
<div v-if="toast.duration" class="h-1 bg-gray-100 dark:bg-dark-700">
|
||||
<div
|
||||
:class="['h-full transition-all', getProgressBarColor(toast.type)]"
|
||||
:style="{ width: `${getProgress(toast)}%` }"
|
||||
:class="['h-full toast-progress', getProgressBarColor(toast.type)]"
|
||||
:style="{ animationDuration: `${toast.duration}ms` }"
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -77,7 +77,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, onUnmounted } from 'vue'
|
||||
import { computed } from 'vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
|
||||
@@ -129,36 +129,25 @@ const getProgressBarColor = (type: string): string => {
|
||||
return colors[type] || colors.info
|
||||
}
|
||||
|
||||
const getProgress = (toast: any): number => {
|
||||
if (!toast.duration || !toast.startTime) return 100
|
||||
const elapsed = Date.now() - toast.startTime
|
||||
const progress = Math.max(0, 100 - (elapsed / toast.duration) * 100)
|
||||
return progress
|
||||
}
|
||||
|
||||
const removeToast = (id: string) => {
|
||||
appStore.hideToast(id)
|
||||
}
|
||||
|
||||
let intervalId: number | undefined
|
||||
|
||||
onMounted(() => {
|
||||
// Check for expired toasts every 100ms
|
||||
intervalId = window.setInterval(() => {
|
||||
const now = Date.now()
|
||||
toasts.value.forEach((toast) => {
|
||||
if (toast.duration && toast.startTime) {
|
||||
if (now - toast.startTime >= toast.duration) {
|
||||
removeToast(toast.id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}, 100)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (intervalId !== undefined) {
|
||||
clearInterval(intervalId)
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.toast-progress {
|
||||
width: 100%;
|
||||
animation-name: toast-progress-shrink;
|
||||
animation-timing-function: linear;
|
||||
animation-fill-mode: forwards;
|
||||
}
|
||||
|
||||
@keyframes toast-progress-shrink {
|
||||
from {
|
||||
width: 100%;
|
||||
}
|
||||
to {
|
||||
width: 0%;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -143,7 +143,7 @@
|
||||
<!-- Options (for select/multi_select) -->
|
||||
<div v-if="form.type === 'select' || form.type === 'multi_select'" class="space-y-2">
|
||||
<label class="input-label">{{ t('admin.users.attributes.options') }}</label>
|
||||
<div v-for="(option, index) in form.options" :key="index" class="flex items-center gap-2">
|
||||
<div v-for="(option, index) in form.options" :key="getOptionKey(option)" class="flex items-center gap-2">
|
||||
<input
|
||||
v-model="option.value"
|
||||
type="text"
|
||||
@@ -246,6 +246,7 @@ import BaseDialog from '@/components/common/BaseDialog.vue'
|
||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import Select from '@/components/common/Select.vue'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
@@ -270,6 +271,7 @@ const showEditModal = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const editingAttribute = ref<UserAttributeDefinition | null>(null)
|
||||
const deletingAttribute = ref<UserAttributeDefinition | null>(null)
|
||||
const getOptionKey = createStableObjectKeyResolver<UserAttributeOption>('user-attr-option')
|
||||
|
||||
const form = reactive({
|
||||
key: '',
|
||||
@@ -315,7 +317,7 @@ const openEditModal = (attr: UserAttributeDefinition) => {
|
||||
form.placeholder = attr.placeholder || ''
|
||||
form.required = attr.required
|
||||
form.enabled = attr.enabled
|
||||
form.options = attr.options ? [...attr.options] : []
|
||||
form.options = attr.options ? attr.options.map((opt) => ({ ...opt })) : []
|
||||
showEditModal.value = true
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, computed } from 'vue'
|
||||
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { totpAPI } from '@/api'
|
||||
@@ -107,6 +107,7 @@ const loading = ref(false)
|
||||
const error = ref('')
|
||||
const sendingCode = ref(false)
|
||||
const codeCooldown = ref(0)
|
||||
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
|
||||
const form = ref({
|
||||
emailCode: '',
|
||||
password: ''
|
||||
@@ -139,10 +140,17 @@ const handleSendCode = async () => {
|
||||
appStore.showSuccess(t('profile.totp.codeSent'))
|
||||
// Start cooldown
|
||||
codeCooldown.value = 60
|
||||
const timer = setInterval(() => {
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
cooldownTimer.value = setInterval(() => {
|
||||
codeCooldown.value--
|
||||
if (codeCooldown.value <= 0) {
|
||||
clearInterval(timer)
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
}
|
||||
}, 1000)
|
||||
} catch (err: any) {
|
||||
@@ -176,4 +184,11 @@ const handleDisable = async () => {
|
||||
onMounted(() => {
|
||||
loadVerificationMethod()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -175,7 +175,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, nextTick, watch, computed } from 'vue'
|
||||
import { ref, onMounted, onUnmounted, nextTick, watch, computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { totpAPI } from '@/api'
|
||||
@@ -198,6 +198,7 @@ const verifyForm = ref({ emailCode: '', password: '' })
|
||||
const verifyError = ref('')
|
||||
const sendingCode = ref(false)
|
||||
const codeCooldown = ref(0)
|
||||
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
|
||||
|
||||
const setupLoading = ref(false)
|
||||
const setupData = ref<TotpSetupResponse | null>(null)
|
||||
@@ -338,10 +339,17 @@ const handleSendCode = async () => {
|
||||
appStore.showSuccess(t('profile.totp.codeSent'))
|
||||
// Start cooldown
|
||||
codeCooldown.value = 60
|
||||
const timer = setInterval(() => {
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
cooldownTimer.value = setInterval(() => {
|
||||
codeCooldown.value--
|
||||
if (codeCooldown.value <= 0) {
|
||||
clearInterval(timer)
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
}
|
||||
}, 1000)
|
||||
} catch (err: any) {
|
||||
@@ -397,4 +405,11 @@ const handleVerify = async () => {
|
||||
onMounted(() => {
|
||||
loadVerificationMethod()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
if (cooldownTimer.value) {
|
||||
clearInterval(cooldownTimer.value)
|
||||
cooldownTimer.value = null
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import { mount } from '@vue/test-utils'
|
||||
import TotpSetupModal from '@/components/user/profile/TotpSetupModal.vue'
|
||||
import TotpDisableDialog from '@/components/user/profile/TotpDisableDialog.vue'
|
||||
|
||||
const mocks = vi.hoisted(() => ({
|
||||
showSuccess: vi.fn(),
|
||||
showError: vi.fn(),
|
||||
getVerificationMethod: vi.fn(),
|
||||
sendVerifyCode: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({
|
||||
t: (key: string) => key
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/app', () => ({
|
||||
useAppStore: () => ({
|
||||
showSuccess: mocks.showSuccess,
|
||||
showError: mocks.showError
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/api', () => ({
|
||||
totpAPI: {
|
||||
getVerificationMethod: mocks.getVerificationMethod,
|
||||
sendVerifyCode: mocks.sendVerifyCode,
|
||||
initiateSetup: vi.fn(),
|
||||
enable: vi.fn(),
|
||||
disable: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
const flushPromises = async () => {
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
}
|
||||
|
||||
describe('TOTP 弹窗定时器清理', () => {
|
||||
let intervalSeed = 1000
|
||||
let setIntervalSpy: ReturnType<typeof vi.spyOn>
|
||||
let clearIntervalSpy: ReturnType<typeof vi.spyOn>
|
||||
|
||||
beforeEach(() => {
|
||||
intervalSeed = 1000
|
||||
mocks.showSuccess.mockReset()
|
||||
mocks.showError.mockReset()
|
||||
mocks.getVerificationMethod.mockReset()
|
||||
mocks.sendVerifyCode.mockReset()
|
||||
|
||||
mocks.getVerificationMethod.mockResolvedValue({ method: 'email' })
|
||||
mocks.sendVerifyCode.mockResolvedValue({ success: true })
|
||||
|
||||
setIntervalSpy = vi.spyOn(window, 'setInterval').mockImplementation(((handler: TimerHandler) => {
|
||||
void handler
|
||||
intervalSeed += 1
|
||||
return intervalSeed as unknown as number
|
||||
}) as typeof window.setInterval)
|
||||
clearIntervalSpy = vi.spyOn(window, 'clearInterval')
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
setIntervalSpy.mockRestore()
|
||||
clearIntervalSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('TotpSetupModal 卸载时清理倒计时定时器', async () => {
|
||||
const wrapper = mount(TotpSetupModal)
|
||||
await flushPromises()
|
||||
|
||||
const sendButton = wrapper
|
||||
.findAll('button')
|
||||
.find((button) => button.text().includes('profile.totp.sendCode'))
|
||||
|
||||
expect(sendButton).toBeTruthy()
|
||||
await sendButton!.trigger('click')
|
||||
await flushPromises()
|
||||
|
||||
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
|
||||
const timerId = setIntervalSpy.mock.results[0]?.value
|
||||
|
||||
wrapper.unmount()
|
||||
|
||||
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
|
||||
})
|
||||
|
||||
it('TotpDisableDialog 卸载时清理倒计时定时器', async () => {
|
||||
const wrapper = mount(TotpDisableDialog)
|
||||
await flushPromises()
|
||||
|
||||
const sendButton = wrapper
|
||||
.findAll('button')
|
||||
.find((button) => button.text().includes('profile.totp.sendCode'))
|
||||
|
||||
expect(sendButton).toBeTruthy()
|
||||
await sendButton!.trigger('click')
|
||||
await flushPromises()
|
||||
|
||||
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
|
||||
const timerId = setIntervalSpy.mock.results[0]?.value
|
||||
|
||||
wrapper.unmount()
|
||||
|
||||
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,100 @@
|
||||
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
||||
|
||||
const flushPromises = () => Promise.resolve()
|
||||
|
||||
describe('useKeyedDebouncedSearch', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('为不同 key 独立防抖触发搜索', async () => {
|
||||
const search = vi.fn().mockResolvedValue([])
|
||||
const onSuccess = vi.fn()
|
||||
|
||||
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||
delay: 100,
|
||||
search,
|
||||
onSuccess
|
||||
})
|
||||
|
||||
searcher.trigger('a', 'foo')
|
||||
searcher.trigger('b', 'bar')
|
||||
|
||||
expect(search).not.toHaveBeenCalled()
|
||||
|
||||
vi.advanceTimersByTime(100)
|
||||
await flushPromises()
|
||||
|
||||
expect(search).toHaveBeenCalledTimes(2)
|
||||
expect(search).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'foo',
|
||||
expect.objectContaining({ key: 'a', signal: expect.any(AbortSignal) })
|
||||
)
|
||||
expect(search).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'bar',
|
||||
expect.objectContaining({ key: 'b', signal: expect.any(AbortSignal) })
|
||||
)
|
||||
expect(onSuccess).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('同 key 新请求会取消旧请求并忽略过期响应', async () => {
|
||||
const resolves: Array<(value: string[]) => void> = []
|
||||
const search = vi.fn().mockImplementation(
|
||||
() => new Promise<string[]>((resolve) => {
|
||||
resolves.push(resolve)
|
||||
})
|
||||
)
|
||||
const onSuccess = vi.fn()
|
||||
|
||||
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||
delay: 50,
|
||||
search,
|
||||
onSuccess
|
||||
})
|
||||
|
||||
searcher.trigger('rule-1', 'first')
|
||||
vi.advanceTimersByTime(50)
|
||||
await flushPromises()
|
||||
|
||||
searcher.trigger('rule-1', 'second')
|
||||
vi.advanceTimersByTime(50)
|
||||
await flushPromises()
|
||||
|
||||
expect(search).toHaveBeenCalledTimes(2)
|
||||
|
||||
resolves[1](['second'])
|
||||
await flushPromises()
|
||||
expect(onSuccess).toHaveBeenCalledTimes(1)
|
||||
expect(onSuccess).toHaveBeenLastCalledWith('rule-1', ['second'])
|
||||
|
||||
resolves[0](['first'])
|
||||
await flushPromises()
|
||||
expect(onSuccess).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('clearKey 会取消未执行任务', () => {
|
||||
const search = vi.fn().mockResolvedValue([])
|
||||
const onSuccess = vi.fn()
|
||||
|
||||
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||
delay: 100,
|
||||
search,
|
||||
onSuccess
|
||||
})
|
||||
|
||||
searcher.trigger('a', 'foo')
|
||||
searcher.clearKey('a')
|
||||
|
||||
vi.advanceTimersByTime(100)
|
||||
|
||||
expect(search).not.toHaveBeenCalled()
|
||||
expect(onSuccess).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
103
frontend/src/composables/useKeyedDebouncedSearch.ts
Normal file
103
frontend/src/composables/useKeyedDebouncedSearch.ts
Normal file
@@ -0,0 +1,103 @@
|
||||
import { getCurrentInstance, onUnmounted } from 'vue'
|
||||
|
||||
export interface KeyedDebouncedSearchContext {
|
||||
key: string
|
||||
signal: AbortSignal
|
||||
}
|
||||
|
||||
interface UseKeyedDebouncedSearchOptions<T> {
|
||||
delay?: number
|
||||
search: (keyword: string, context: KeyedDebouncedSearchContext) => Promise<T>
|
||||
onSuccess: (key: string, result: T) => void
|
||||
onError?: (key: string, error: unknown) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 多实例隔离的防抖搜索:每个 key 有独立的防抖、请求取消与过期响应保护。
|
||||
*/
|
||||
export function useKeyedDebouncedSearch<T>(options: UseKeyedDebouncedSearchOptions<T>) {
|
||||
const delay = options.delay ?? 300
|
||||
const timers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||
const controllers = new Map<string, AbortController>()
|
||||
const versions = new Map<string, number>()
|
||||
|
||||
const clearKey = (key: string) => {
|
||||
const timer = timers.get(key)
|
||||
if (timer) {
|
||||
clearTimeout(timer)
|
||||
timers.delete(key)
|
||||
}
|
||||
|
||||
const controller = controllers.get(key)
|
||||
if (controller) {
|
||||
controller.abort()
|
||||
controllers.delete(key)
|
||||
}
|
||||
|
||||
versions.delete(key)
|
||||
}
|
||||
|
||||
const clearAll = () => {
|
||||
const allKeys = new Set<string>([
|
||||
...timers.keys(),
|
||||
...controllers.keys(),
|
||||
...versions.keys()
|
||||
])
|
||||
|
||||
allKeys.forEach((key) => clearKey(key))
|
||||
}
|
||||
|
||||
const trigger = (key: string, keyword: string) => {
|
||||
const nextVersion = (versions.get(key) ?? 0) + 1
|
||||
versions.set(key, nextVersion)
|
||||
|
||||
const existingTimer = timers.get(key)
|
||||
if (existingTimer) {
|
||||
clearTimeout(existingTimer)
|
||||
timers.delete(key)
|
||||
}
|
||||
|
||||
const inFlight = controllers.get(key)
|
||||
if (inFlight) {
|
||||
inFlight.abort()
|
||||
controllers.delete(key)
|
||||
}
|
||||
|
||||
const timer = setTimeout(async () => {
|
||||
timers.delete(key)
|
||||
|
||||
const controller = new AbortController()
|
||||
controllers.set(key, controller)
|
||||
const requestVersion = versions.get(key)
|
||||
|
||||
try {
|
||||
const result = await options.search(keyword, { key, signal: controller.signal })
|
||||
if (controller.signal.aborted) return
|
||||
if (versions.get(key) !== requestVersion) return
|
||||
options.onSuccess(key, result)
|
||||
} catch (error) {
|
||||
if (controller.signal.aborted) return
|
||||
if (versions.get(key) !== requestVersion) return
|
||||
options.onError?.(key, error)
|
||||
} finally {
|
||||
if (controllers.get(key) === controller) {
|
||||
controllers.delete(key)
|
||||
}
|
||||
}
|
||||
}, delay)
|
||||
|
||||
timers.set(key, timer)
|
||||
}
|
||||
|
||||
if (getCurrentInstance()) {
|
||||
onUnmounted(() => {
|
||||
clearAll()
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
trigger,
|
||||
clearKey,
|
||||
clearAll
|
||||
}
|
||||
}
|
||||
@@ -1,53 +1,83 @@
|
||||
import { createI18n } from 'vue-i18n'
|
||||
import en from './locales/en'
|
||||
import zh from './locales/zh'
|
||||
|
||||
type LocaleCode = 'en' | 'zh'
|
||||
|
||||
type LocaleMessages = Record<string, any>
|
||||
|
||||
const LOCALE_KEY = 'sub2api_locale'
|
||||
const DEFAULT_LOCALE: LocaleCode = 'en'
|
||||
|
||||
function getDefaultLocale(): string {
|
||||
// Check localStorage first
|
||||
const localeLoaders: Record<LocaleCode, () => Promise<{ default: LocaleMessages }>> = {
|
||||
en: () => import('./locales/en'),
|
||||
zh: () => import('./locales/zh')
|
||||
}
|
||||
|
||||
function isLocaleCode(value: string): value is LocaleCode {
|
||||
return value === 'en' || value === 'zh'
|
||||
}
|
||||
|
||||
function getDefaultLocale(): LocaleCode {
|
||||
const saved = localStorage.getItem(LOCALE_KEY)
|
||||
if (saved && ['en', 'zh'].includes(saved)) {
|
||||
if (saved && isLocaleCode(saved)) {
|
||||
return saved
|
||||
}
|
||||
|
||||
// Check browser language
|
||||
const browserLang = navigator.language.toLowerCase()
|
||||
if (browserLang.startsWith('zh')) {
|
||||
return 'zh'
|
||||
}
|
||||
|
||||
return 'en'
|
||||
return DEFAULT_LOCALE
|
||||
}
|
||||
|
||||
export const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: getDefaultLocale(),
|
||||
fallbackLocale: 'en',
|
||||
messages: {
|
||||
en,
|
||||
zh
|
||||
},
|
||||
fallbackLocale: DEFAULT_LOCALE,
|
||||
messages: {},
|
||||
// 禁用 HTML 消息警告 - 引导步骤使用富文本内容(driver.js 支持 HTML)
|
||||
// 这些内容是内部定义的,不存在 XSS 风险
|
||||
warnHtmlMessage: false
|
||||
})
|
||||
|
||||
export function setLocale(locale: string) {
|
||||
if (['en', 'zh'].includes(locale)) {
|
||||
i18n.global.locale.value = locale as 'en' | 'zh'
|
||||
const loadedLocales = new Set<LocaleCode>()
|
||||
|
||||
export async function loadLocaleMessages(locale: LocaleCode): Promise<void> {
|
||||
if (loadedLocales.has(locale)) {
|
||||
return
|
||||
}
|
||||
|
||||
const loader = localeLoaders[locale]
|
||||
const module = await loader()
|
||||
i18n.global.setLocaleMessage(locale, module.default)
|
||||
loadedLocales.add(locale)
|
||||
}
|
||||
|
||||
export async function initI18n(): Promise<void> {
|
||||
const current = getLocale()
|
||||
await loadLocaleMessages(current)
|
||||
document.documentElement.setAttribute('lang', current)
|
||||
}
|
||||
|
||||
export async function setLocale(locale: string): Promise<void> {
|
||||
if (!isLocaleCode(locale)) {
|
||||
return
|
||||
}
|
||||
|
||||
await loadLocaleMessages(locale)
|
||||
i18n.global.locale.value = locale
|
||||
localStorage.setItem(LOCALE_KEY, locale)
|
||||
document.documentElement.setAttribute('lang', locale)
|
||||
}
|
||||
}
|
||||
|
||||
export function getLocale(): string {
|
||||
return i18n.global.locale.value
|
||||
export function getLocale(): LocaleCode {
|
||||
const current = i18n.global.locale.value
|
||||
return isLocaleCode(current) ? current : DEFAULT_LOCALE
|
||||
}
|
||||
|
||||
export const availableLocales = [
|
||||
{ code: 'en', name: 'English', flag: '🇺🇸' },
|
||||
{ code: 'zh', name: '中文', flag: '🇨🇳' }
|
||||
]
|
||||
] as const
|
||||
|
||||
export default i18n
|
||||
|
||||
@@ -2,16 +2,17 @@ import { createApp } from 'vue'
|
||||
import { createPinia } from 'pinia'
|
||||
import App from './App.vue'
|
||||
import router from './router'
|
||||
import i18n from './i18n'
|
||||
import i18n, { initI18n } from './i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import './style.css'
|
||||
|
||||
async function bootstrap() {
|
||||
const app = createApp(App)
|
||||
const pinia = createPinia()
|
||||
app.use(pinia)
|
||||
|
||||
// Initialize settings from injected config BEFORE mounting (prevents flash)
|
||||
// This must happen after pinia is installed but before router and i18n
|
||||
import { useAppStore } from '@/stores/app'
|
||||
const appStore = useAppStore()
|
||||
appStore.initFromInjectedConfig()
|
||||
|
||||
@@ -20,10 +21,14 @@ if (appStore.siteName && appStore.siteName !== 'Sub2API') {
|
||||
document.title = `${appStore.siteName} - AI API Gateway`
|
||||
}
|
||||
|
||||
await initI18n()
|
||||
|
||||
app.use(router)
|
||||
app.use(i18n)
|
||||
|
||||
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
|
||||
router.isReady().then(() => {
|
||||
await router.isReady()
|
||||
app.mount('#app')
|
||||
})
|
||||
}
|
||||
|
||||
bootstrap()
|
||||
|
||||
25
frontend/src/router/__tests__/title.spec.ts
Normal file
25
frontend/src/router/__tests__/title.spec.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { resolveDocumentTitle } from '@/router/title'
|
||||
|
||||
describe('resolveDocumentTitle', () => {
|
||||
it('路由存在标题时,使用“路由标题 - 站点名”格式', () => {
|
||||
expect(resolveDocumentTitle('Usage Records', 'My Site')).toBe('Usage Records - My Site')
|
||||
})
|
||||
|
||||
it('路由无标题时,回退到站点名', () => {
|
||||
expect(resolveDocumentTitle(undefined, 'My Site')).toBe('My Site')
|
||||
})
|
||||
|
||||
it('站点名为空时,回退默认站点名', () => {
|
||||
expect(resolveDocumentTitle('Dashboard', '')).toBe('Dashboard - Sub2API')
|
||||
expect(resolveDocumentTitle(undefined, ' ')).toBe('Sub2API')
|
||||
})
|
||||
|
||||
it('站点名变更时仅影响后续路由标题计算', () => {
|
||||
const before = resolveDocumentTitle('Admin Dashboard', 'Alpha')
|
||||
const after = resolveDocumentTitle('Admin Dashboard', 'Beta')
|
||||
|
||||
expect(before).toBe('Admin Dashboard - Alpha')
|
||||
expect(after).toBe('Admin Dashboard - Beta')
|
||||
})
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import { useAuthStore } from '@/stores/auth'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useNavigationLoadingState } from '@/composables/useNavigationLoading'
|
||||
import { useRoutePrefetch } from '@/composables/useRoutePrefetch'
|
||||
import { resolveDocumentTitle } from './title'
|
||||
|
||||
/**
|
||||
* Route definitions with lazy loading
|
||||
@@ -389,12 +390,7 @@ router.beforeEach((to, _from, next) => {
|
||||
|
||||
// Set page title
|
||||
const appStore = useAppStore()
|
||||
const siteName = appStore.siteName || 'Sub2API'
|
||||
if (to.meta.title) {
|
||||
document.title = `${to.meta.title} - ${siteName}`
|
||||
} else {
|
||||
document.title = siteName
|
||||
}
|
||||
document.title = resolveDocumentTitle(to.meta.title, appStore.siteName)
|
||||
|
||||
// Check if route requires authentication
|
||||
const requiresAuth = to.meta.requiresAuth !== false // Default to true
|
||||
|
||||
12
frontend/src/router/title.ts
Normal file
12
frontend/src/router/title.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
/**
|
||||
* 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。
|
||||
*/
|
||||
export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string {
|
||||
const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API'
|
||||
|
||||
if (typeof routeTitle === 'string' && routeTitle.trim()) {
|
||||
return `${routeTitle.trim()} - ${normalizedSiteName}`
|
||||
}
|
||||
|
||||
return normalizedSiteName
|
||||
}
|
||||
37
frontend/src/utils/__tests__/stableObjectKey.spec.ts
Normal file
37
frontend/src/utils/__tests__/stableObjectKey.spec.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
|
||||
describe('createStableObjectKeyResolver', () => {
|
||||
it('对同一对象返回稳定 key', () => {
|
||||
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
|
||||
const obj = { value: 'a' }
|
||||
|
||||
const key1 = resolve(obj)
|
||||
const key2 = resolve(obj)
|
||||
|
||||
expect(key1).toBe(key2)
|
||||
expect(key1.startsWith('rule-')).toBe(true)
|
||||
})
|
||||
|
||||
it('不同对象返回不同 key', () => {
|
||||
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
|
||||
|
||||
const key1 = resolve({ value: 'a' })
|
||||
const key2 = resolve({ value: 'a' })
|
||||
|
||||
expect(key1).not.toBe(key2)
|
||||
})
|
||||
|
||||
it('不同 resolver 互不影响', () => {
|
||||
const resolveA = createStableObjectKeyResolver<{ id: number }>('a')
|
||||
const resolveB = createStableObjectKeyResolver<{ id: number }>('b')
|
||||
const obj = { id: 1 }
|
||||
|
||||
const keyA = resolveA(obj)
|
||||
const keyB = resolveB(obj)
|
||||
|
||||
expect(keyA).not.toBe(keyB)
|
||||
expect(keyA.startsWith('a-')).toBe(true)
|
||||
expect(keyB.startsWith('b-')).toBe(true)
|
||||
})
|
||||
})
|
||||
19
frontend/src/utils/stableObjectKey.ts
Normal file
19
frontend/src/utils/stableObjectKey.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
let globalStableObjectKeySeed = 0
|
||||
|
||||
/**
|
||||
* 为对象实例生成稳定 key(基于 WeakMap,不污染业务对象)
|
||||
*/
|
||||
export function createStableObjectKeyResolver<T extends object>(prefix = 'item') {
|
||||
const keyMap = new WeakMap<T, string>()
|
||||
|
||||
return (item: T): string => {
|
||||
const cached = keyMap.get(item)
|
||||
if (cached) {
|
||||
return cached
|
||||
}
|
||||
|
||||
const key = `${prefix}-${++globalStableObjectKeySeed}`
|
||||
keyMap.set(item, key)
|
||||
return key
|
||||
}
|
||||
}
|
||||
@@ -239,8 +239,8 @@
|
||||
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template>
|
||||
</TablePageLayout>
|
||||
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
|
||||
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="load" />
|
||||
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="load" />
|
||||
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="handleAccountUpdated" />
|
||||
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="handleAccountUpdated" />
|
||||
<AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" />
|
||||
<AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" />
|
||||
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
|
||||
@@ -694,6 +694,53 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
|
||||
}
|
||||
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
|
||||
const handleDataImported = () => { showImportData.value = false; reload() }
|
||||
const accountMatchesCurrentFilters = (account: Account) => {
|
||||
if (params.platform && account.platform !== params.platform) return false
|
||||
if (params.type && account.type !== params.type) return false
|
||||
if (params.status) {
|
||||
if (params.status === 'rate_limited') {
|
||||
if (!account.rate_limit_reset_at) return false
|
||||
const resetAt = new Date(account.rate_limit_reset_at).getTime()
|
||||
if (!Number.isFinite(resetAt) || resetAt <= Date.now()) return false
|
||||
} else if (account.status !== params.status) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
const search = String(params.search || '').trim().toLowerCase()
|
||||
if (search && !account.name.toLowerCase().includes(search)) return false
|
||||
return true
|
||||
}
|
||||
const mergeRuntimeFields = (oldAccount: Account, updatedAccount: Account): Account => ({
|
||||
...updatedAccount,
|
||||
current_concurrency: updatedAccount.current_concurrency ?? oldAccount.current_concurrency,
|
||||
current_window_cost: updatedAccount.current_window_cost ?? oldAccount.current_window_cost,
|
||||
active_sessions: updatedAccount.active_sessions ?? oldAccount.active_sessions
|
||||
})
|
||||
const patchAccountInList = (updatedAccount: Account) => {
|
||||
const index = accounts.value.findIndex(account => account.id === updatedAccount.id)
|
||||
if (index === -1) return
|
||||
const mergedAccount = mergeRuntimeFields(accounts.value[index], updatedAccount)
|
||||
if (!accountMatchesCurrentFilters(mergedAccount)) {
|
||||
accounts.value = accounts.value.filter(account => account.id !== mergedAccount.id)
|
||||
selIds.value = selIds.value.filter(id => id !== mergedAccount.id)
|
||||
if (menu.acc?.id === mergedAccount.id) {
|
||||
menu.show = false
|
||||
menu.acc = null
|
||||
}
|
||||
return
|
||||
}
|
||||
const nextAccounts = [...accounts.value]
|
||||
nextAccounts[index] = mergedAccount
|
||||
accounts.value = nextAccounts
|
||||
if (edAcc.value?.id === mergedAccount.id) edAcc.value = mergedAccount
|
||||
if (reAuthAcc.value?.id === mergedAccount.id) reAuthAcc.value = mergedAccount
|
||||
if (tempUnschedAcc.value?.id === mergedAccount.id) tempUnschedAcc.value = mergedAccount
|
||||
if (deletingAcc.value?.id === mergedAccount.id) deletingAcc.value = mergedAccount
|
||||
if (menu.acc?.id === mergedAccount.id) menu.acc = mergedAccount
|
||||
}
|
||||
const handleAccountUpdated = (updatedAccount: Account) => {
|
||||
patchAccountInList(updatedAccount)
|
||||
}
|
||||
const formatExportTimestamp = () => {
|
||||
const now = new Date()
|
||||
const pad2 = (value: number) => String(value).padStart(2, '0')
|
||||
@@ -743,9 +790,32 @@ const closeReAuthModal = () => { showReAuth.value = false; reAuthAcc.value = nul
|
||||
const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true }
|
||||
const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true }
|
||||
const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true }
|
||||
const handleRefresh = async (a: Account) => { try { await adminAPI.accounts.refreshCredentials(a.id); load() } catch (error) { console.error('Failed to refresh credentials:', error) } }
|
||||
const handleResetStatus = async (a: Account) => { try { await adminAPI.accounts.clearError(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to reset status:', error) } }
|
||||
const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } }
|
||||
const handleRefresh = async (a: Account) => {
|
||||
try {
|
||||
const updated = await adminAPI.accounts.refreshCredentials(a.id)
|
||||
patchAccountInList(updated)
|
||||
} catch (error) {
|
||||
console.error('Failed to refresh credentials:', error)
|
||||
}
|
||||
}
|
||||
const handleResetStatus = async (a: Account) => {
|
||||
try {
|
||||
const updated = await adminAPI.accounts.clearError(a.id)
|
||||
patchAccountInList(updated)
|
||||
appStore.showSuccess(t('common.success'))
|
||||
} catch (error) {
|
||||
console.error('Failed to reset status:', error)
|
||||
}
|
||||
}
|
||||
const handleClearRateLimit = async (a: Account) => {
|
||||
try {
|
||||
const updated = await adminAPI.accounts.clearRateLimit(a.id)
|
||||
patchAccountInList(updated)
|
||||
appStore.showSuccess(t('common.success'))
|
||||
} catch (error) {
|
||||
console.error('Failed to clear rate limit:', error)
|
||||
}
|
||||
}
|
||||
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
|
||||
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
|
||||
const handleToggleSchedulable = async (a: Account) => {
|
||||
@@ -762,7 +832,17 @@ const handleToggleSchedulable = async (a: Account) => {
|
||||
}
|
||||
}
|
||||
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
|
||||
const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } }
|
||||
const handleTempUnschedReset = async () => {
|
||||
if(!tempUnschedAcc.value) return
|
||||
try {
|
||||
const updated = await adminAPI.accounts.clearError(tempUnschedAcc.value.id)
|
||||
showTempUnsched.value = false
|
||||
tempUnschedAcc.value = null
|
||||
patchAccountInList(updated)
|
||||
} catch (error) {
|
||||
console.error('Failed to reset temp unscheduled:', error)
|
||||
}
|
||||
}
|
||||
const formatExpiresAt = (value: number | null) => {
|
||||
if (!value) return '-'
|
||||
return formatDateTime(
|
||||
|
||||
@@ -759,8 +759,8 @@
|
||||
<!-- 路由规则列表(仅在启用时显示) -->
|
||||
<div v-if="createForm.model_routing_enabled" class="space-y-3">
|
||||
<div
|
||||
v-for="(rule, index) in createModelRoutingRules"
|
||||
:key="index"
|
||||
v-for="rule in createModelRoutingRules"
|
||||
:key="getCreateRuleRenderKey(rule)"
|
||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||
>
|
||||
<div class="flex items-start gap-3">
|
||||
@@ -786,7 +786,7 @@
|
||||
{{ account.name }}
|
||||
<button
|
||||
type="button"
|
||||
@click="removeSelectedAccount(index, account.id, false)"
|
||||
@click="removeSelectedAccount(rule, account.id)"
|
||||
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
||||
>
|
||||
<Icon name="x" size="xs" />
|
||||
@@ -796,23 +796,23 @@
|
||||
<!-- 账号搜索输入框 -->
|
||||
<div class="relative account-search-container">
|
||||
<input
|
||||
v-model="accountSearchKeyword[`create-${index}`]"
|
||||
v-model="accountSearchKeyword[getCreateRuleSearchKey(rule)]"
|
||||
type="text"
|
||||
class="input text-sm"
|
||||
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
||||
@input="searchAccounts(`create-${index}`)"
|
||||
@focus="onAccountSearchFocus(index, false)"
|
||||
@input="searchAccountsByRule(rule)"
|
||||
@focus="onAccountSearchFocus(rule)"
|
||||
/>
|
||||
<!-- 搜索结果下拉框 -->
|
||||
<div
|
||||
v-if="showAccountDropdown[`create-${index}`] && accountSearchResults[`create-${index}`]?.length > 0"
|
||||
v-if="showAccountDropdown[getCreateRuleSearchKey(rule)] && accountSearchResults[getCreateRuleSearchKey(rule)]?.length > 0"
|
||||
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
v-for="account in accountSearchResults[`create-${index}`]"
|
||||
v-for="account in accountSearchResults[getCreateRuleSearchKey(rule)]"
|
||||
:key="account.id"
|
||||
type="button"
|
||||
@click="selectAccount(index, account, false)"
|
||||
@click="selectAccount(rule, account)"
|
||||
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
||||
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
||||
:disabled="rule.accounts.some(a => a.id === account.id)"
|
||||
@@ -827,7 +827,7 @@
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeCreateRoutingRule(index)"
|
||||
@click="removeCreateRoutingRule(rule)"
|
||||
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
||||
:title="t('admin.groups.modelRouting.removeRule')"
|
||||
>
|
||||
@@ -1439,8 +1439,8 @@
|
||||
<!-- 路由规则列表(仅在启用时显示) -->
|
||||
<div v-if="editForm.model_routing_enabled" class="space-y-3">
|
||||
<div
|
||||
v-for="(rule, index) in editModelRoutingRules"
|
||||
:key="index"
|
||||
v-for="rule in editModelRoutingRules"
|
||||
:key="getEditRuleRenderKey(rule)"
|
||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||
>
|
||||
<div class="flex items-start gap-3">
|
||||
@@ -1466,7 +1466,7 @@
|
||||
{{ account.name }}
|
||||
<button
|
||||
type="button"
|
||||
@click="removeSelectedAccount(index, account.id, true)"
|
||||
@click="removeSelectedAccount(rule, account.id, true)"
|
||||
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
||||
>
|
||||
<Icon name="x" size="xs" />
|
||||
@@ -1476,23 +1476,23 @@
|
||||
<!-- 账号搜索输入框 -->
|
||||
<div class="relative account-search-container">
|
||||
<input
|
||||
v-model="accountSearchKeyword[`edit-${index}`]"
|
||||
v-model="accountSearchKeyword[getEditRuleSearchKey(rule)]"
|
||||
type="text"
|
||||
class="input text-sm"
|
||||
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
||||
@input="searchAccounts(`edit-${index}`)"
|
||||
@focus="onAccountSearchFocus(index, true)"
|
||||
@input="searchAccountsByRule(rule, true)"
|
||||
@focus="onAccountSearchFocus(rule, true)"
|
||||
/>
|
||||
<!-- 搜索结果下拉框 -->
|
||||
<div
|
||||
v-if="showAccountDropdown[`edit-${index}`] && accountSearchResults[`edit-${index}`]?.length > 0"
|
||||
v-if="showAccountDropdown[getEditRuleSearchKey(rule)] && accountSearchResults[getEditRuleSearchKey(rule)]?.length > 0"
|
||||
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
||||
>
|
||||
<button
|
||||
v-for="account in accountSearchResults[`edit-${index}`]"
|
||||
v-for="account in accountSearchResults[getEditRuleSearchKey(rule)]"
|
||||
:key="account.id"
|
||||
type="button"
|
||||
@click="selectAccount(index, account, true)"
|
||||
@click="selectAccount(rule, account, true)"
|
||||
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
||||
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
||||
:disabled="rule.accounts.some(a => a.id === account.id)"
|
||||
@@ -1507,7 +1507,7 @@
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
@click="removeEditRoutingRule(index)"
|
||||
@click="removeEditRoutingRule(rule)"
|
||||
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
||||
:title="t('admin.groups.modelRouting.removeRule')"
|
||||
>
|
||||
@@ -1687,6 +1687,8 @@ import Select from '@/components/common/Select.vue'
|
||||
import PlatformIcon from '@/components/common/PlatformIcon.vue'
|
||||
import Icon from '@/components/icons/Icon.vue'
|
||||
import { VueDraggable } from 'vue-draggable-plus'
|
||||
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
@@ -1911,33 +1913,70 @@ const createModelRoutingRules = ref<ModelRoutingRule[]>([])
|
||||
// 编辑表单的模型路由规则
|
||||
const editModelRoutingRules = ref<ModelRoutingRule[]>([])
|
||||
|
||||
// 账号搜索相关状态
|
||||
const accountSearchKeyword = ref<Record<string, string>>({}) // 每个规则的搜索关键词 (key: "create-0" 或 "edit-0")
|
||||
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({}) // 每个规则的搜索结果
|
||||
const showAccountDropdown = ref<Record<string, boolean>>({}) // 每个规则的下拉框显示状态
|
||||
let accountSearchTimeout: ReturnType<typeof setTimeout> | null = null
|
||||
// 规则对象稳定 key(避免使用 index 导致状态错位)
|
||||
const resolveCreateRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('create-rule')
|
||||
const resolveEditRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('edit-rule')
|
||||
|
||||
// 搜索账号(仅限 anthropic 平台)
|
||||
const searchAccounts = async (key: string) => {
|
||||
if (accountSearchTimeout) clearTimeout(accountSearchTimeout)
|
||||
accountSearchTimeout = setTimeout(async () => {
|
||||
const keyword = accountSearchKeyword.value[key] || ''
|
||||
try {
|
||||
const res = await adminAPI.accounts.list(1, 20, {
|
||||
const getCreateRuleRenderKey = (rule: ModelRoutingRule) => resolveCreateRuleKey(rule)
|
||||
const getEditRuleRenderKey = (rule: ModelRoutingRule) => resolveEditRuleKey(rule)
|
||||
|
||||
const getCreateRuleSearchKey = (rule: ModelRoutingRule) => `create-${resolveCreateRuleKey(rule)}`
|
||||
const getEditRuleSearchKey = (rule: ModelRoutingRule) => `edit-${resolveEditRuleKey(rule)}`
|
||||
|
||||
const getRuleSearchKey = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||
return isEdit ? getEditRuleSearchKey(rule) : getCreateRuleSearchKey(rule)
|
||||
}
|
||||
|
||||
// 账号搜索相关状态
|
||||
const accountSearchKeyword = ref<Record<string, string>>({})
|
||||
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({})
|
||||
const showAccountDropdown = ref<Record<string, boolean>>({})
|
||||
|
||||
const clearAccountSearchStateByKey = (key: string) => {
|
||||
delete accountSearchKeyword.value[key]
|
||||
delete accountSearchResults.value[key]
|
||||
delete showAccountDropdown.value[key]
|
||||
}
|
||||
|
||||
const clearAllAccountSearchState = () => {
|
||||
accountSearchKeyword.value = {}
|
||||
accountSearchResults.value = {}
|
||||
showAccountDropdown.value = {}
|
||||
}
|
||||
|
||||
const accountSearchRunner = useKeyedDebouncedSearch<SimpleAccount[]>({
|
||||
delay: 300,
|
||||
search: async (keyword, { signal }) => {
|
||||
const res = await adminAPI.accounts.list(
|
||||
1,
|
||||
20,
|
||||
{
|
||||
search: keyword,
|
||||
platform: 'anthropic'
|
||||
})
|
||||
accountSearchResults.value[key] = res.items.map((a) => ({ id: a.id, name: a.name }))
|
||||
} catch {
|
||||
},
|
||||
{ signal }
|
||||
)
|
||||
return res.items.map((account) => ({ id: account.id, name: account.name }))
|
||||
},
|
||||
onSuccess: (key, result) => {
|
||||
accountSearchResults.value[key] = result
|
||||
},
|
||||
onError: (key) => {
|
||||
accountSearchResults.value[key] = []
|
||||
}
|
||||
}, 300)
|
||||
})
|
||||
|
||||
// 搜索账号(仅限 anthropic 平台)
|
||||
const searchAccounts = (key: string) => {
|
||||
accountSearchRunner.trigger(key, accountSearchKeyword.value[key] || '')
|
||||
}
|
||||
|
||||
const searchAccountsByRule = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||
searchAccounts(getRuleSearchKey(rule, isEdit))
|
||||
}
|
||||
|
||||
// 选择账号
|
||||
const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolean = false) => {
|
||||
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
|
||||
const rule = rules[ruleIndex]
|
||||
const selectAccount = (rule: ModelRoutingRule, account: SimpleAccount, isEdit: boolean = false) => {
|
||||
if (!rule) return
|
||||
|
||||
// 检查是否已选择
|
||||
@@ -1946,15 +1985,13 @@ const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolea
|
||||
}
|
||||
|
||||
// 清空搜索
|
||||
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}`
|
||||
const key = getRuleSearchKey(rule, isEdit)
|
||||
accountSearchKeyword.value[key] = ''
|
||||
showAccountDropdown.value[key] = false
|
||||
}
|
||||
|
||||
// 移除已选账号
|
||||
const removeSelectedAccount = (ruleIndex: number, accountId: number, isEdit: boolean = false) => {
|
||||
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
|
||||
const rule = rules[ruleIndex]
|
||||
const removeSelectedAccount = (rule: ModelRoutingRule, accountId: number, _isEdit: boolean = false) => {
|
||||
if (!rule) return
|
||||
|
||||
rule.accounts = rule.accounts.filter(a => a.id !== accountId)
|
||||
@@ -1981,8 +2018,8 @@ const toggleEditScope = (scope: string) => {
|
||||
}
|
||||
|
||||
// 处理账号搜索输入框聚焦
|
||||
const onAccountSearchFocus = (ruleIndex: number, isEdit: boolean = false) => {
|
||||
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}`
|
||||
const onAccountSearchFocus = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||
const key = getRuleSearchKey(rule, isEdit)
|
||||
showAccountDropdown.value[key] = true
|
||||
// 如果没有搜索结果,触发一次搜索
|
||||
if (!accountSearchResults.value[key]?.length) {
|
||||
@@ -1996,13 +2033,14 @@ const addCreateRoutingRule = () => {
|
||||
}
|
||||
|
||||
// 删除创建表单的路由规则
|
||||
const removeCreateRoutingRule = (index: number) => {
|
||||
const removeCreateRoutingRule = (rule: ModelRoutingRule) => {
|
||||
const index = createModelRoutingRules.value.indexOf(rule)
|
||||
if (index === -1) return
|
||||
|
||||
const key = getCreateRuleSearchKey(rule)
|
||||
accountSearchRunner.clearKey(key)
|
||||
clearAccountSearchStateByKey(key)
|
||||
createModelRoutingRules.value.splice(index, 1)
|
||||
// 清理相关的搜索状态
|
||||
const key = `create-${index}`
|
||||
delete accountSearchKeyword.value[key]
|
||||
delete accountSearchResults.value[key]
|
||||
delete showAccountDropdown.value[key]
|
||||
}
|
||||
|
||||
// 添加编辑表单的路由规则
|
||||
@@ -2011,13 +2049,14 @@ const addEditRoutingRule = () => {
|
||||
}
|
||||
|
||||
// 删除编辑表单的路由规则
|
||||
const removeEditRoutingRule = (index: number) => {
|
||||
const removeEditRoutingRule = (rule: ModelRoutingRule) => {
|
||||
const index = editModelRoutingRules.value.indexOf(rule)
|
||||
if (index === -1) return
|
||||
|
||||
const key = getEditRuleSearchKey(rule)
|
||||
accountSearchRunner.clearKey(key)
|
||||
clearAccountSearchStateByKey(key)
|
||||
editModelRoutingRules.value.splice(index, 1)
|
||||
// 清理相关的搜索状态
|
||||
const key = `edit-${index}`
|
||||
delete accountSearchKeyword.value[key]
|
||||
delete accountSearchResults.value[key]
|
||||
delete showAccountDropdown.value[key]
|
||||
}
|
||||
|
||||
// 将 UI 格式的路由规则转换为 API 格式
|
||||
@@ -2161,6 +2200,10 @@ const handlePageSizeChange = (pageSize: number) => {
|
||||
|
||||
const closeCreateModal = () => {
|
||||
showCreateModal.value = false
|
||||
createModelRoutingRules.value.forEach((rule) => {
|
||||
accountSearchRunner.clearKey(getCreateRuleSearchKey(rule))
|
||||
})
|
||||
clearAllAccountSearchState()
|
||||
createForm.name = ''
|
||||
createForm.description = ''
|
||||
createForm.platform = 'anthropic'
|
||||
@@ -2247,6 +2290,10 @@ const handleEdit = async (group: AdminGroup) => {
|
||||
}
|
||||
|
||||
const closeEditModal = () => {
|
||||
editModelRoutingRules.value.forEach((rule) => {
|
||||
accountSearchRunner.clearKey(getEditRuleSearchKey(rule))
|
||||
})
|
||||
clearAllAccountSearchState()
|
||||
showEditModal.value = false
|
||||
editingGroup.value = null
|
||||
editModelRoutingRules.value = []
|
||||
@@ -2382,5 +2429,7 @@ onMounted(() => {
|
||||
|
||||
onUnmounted(() => {
|
||||
document.removeEventListener('click', handleClickOutside)
|
||||
accountSearchRunner.clearAll()
|
||||
clearAllAccountSearchState()
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -94,15 +94,7 @@ const exportToExcel = async () => {
|
||||
if (exporting.value) return; exporting.value = true; exportProgress.show = true
|
||||
const c = new AbortController(); exportAbortController = c
|
||||
try {
|
||||
const all: AdminUsageLog[] = []; let p = 1; let total = pagination.total
|
||||
while (true) {
|
||||
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
|
||||
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
|
||||
if (res.items?.length) all.push(...res.items)
|
||||
exportProgress.current = all.length; exportProgress.progress = total > 0 ? Math.min(100, Math.round(all.length/total*100)) : 0
|
||||
if (all.length >= total || res.items.length < 100) break; p++
|
||||
}
|
||||
if(!c.signal.aborted) {
|
||||
let p = 1; let total = pagination.total; let exportedCount = 0
|
||||
const XLSX = await import('xlsx')
|
||||
const headers = [
|
||||
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
|
||||
@@ -116,35 +108,30 @@ const exportToExcel = async () => {
|
||||
t('usage.firstToken'), t('usage.duration'),
|
||||
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
|
||||
]
|
||||
const rows = all.map(log => [
|
||||
log.created_at,
|
||||
log.user?.email || '',
|
||||
log.api_key?.name || '',
|
||||
log.account?.name || '',
|
||||
log.model,
|
||||
formatReasoningEffort(log.reasoning_effort),
|
||||
log.group?.name || '',
|
||||
log.stream ? t('usage.stream') : t('usage.sync'),
|
||||
log.input_tokens,
|
||||
log.output_tokens,
|
||||
log.cache_read_tokens,
|
||||
log.cache_creation_tokens,
|
||||
log.input_cost?.toFixed(6) || '0.000000',
|
||||
log.output_cost?.toFixed(6) || '0.000000',
|
||||
log.cache_read_cost?.toFixed(6) || '0.000000',
|
||||
log.cache_creation_cost?.toFixed(6) || '0.000000',
|
||||
log.rate_multiplier?.toFixed(2) || '1.00',
|
||||
(log.account_rate_multiplier ?? 1).toFixed(2),
|
||||
log.total_cost?.toFixed(6) || '0.000000',
|
||||
log.actual_cost?.toFixed(6) || '0.000000',
|
||||
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6),
|
||||
log.first_token_ms ?? '',
|
||||
log.duration_ms,
|
||||
log.request_id || '',
|
||||
log.user_agent || '',
|
||||
log.ip_address || ''
|
||||
const ws = XLSX.utils.aoa_to_sheet([headers])
|
||||
while (true) {
|
||||
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
|
||||
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
|
||||
const rows = (res.items || []).map((log: AdminUsageLog) => [
|
||||
log.created_at, log.user?.email || '', log.api_key?.name || '', log.account?.name || '', log.model,
|
||||
formatReasoningEffort(log.reasoning_effort), log.group?.name || '', log.stream ? t('usage.stream') : t('usage.sync'),
|
||||
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
|
||||
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
|
||||
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
|
||||
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2),
|
||||
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
|
||||
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
|
||||
log.request_id || '', log.user_agent || '', log.ip_address || ''
|
||||
])
|
||||
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
|
||||
if (rows.length) {
|
||||
XLSX.utils.sheet_add_aoa(ws, rows, { origin: -1 })
|
||||
}
|
||||
exportedCount += rows.length
|
||||
exportProgress.current = exportedCount
|
||||
exportProgress.progress = total > 0 ? Math.min(100, Math.round(exportedCount / total * 100)) : 0
|
||||
if (exportedCount >= total || res.items.length < 100) break; p++
|
||||
}
|
||||
if(!c.signal.aborted) {
|
||||
const wb = XLSX.utils.book_new()
|
||||
XLSX.utils.book_append_sheet(wb, ws, 'Usage')
|
||||
saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`)
|
||||
|
||||
Reference in New Issue
Block a user