Merge branch 'test' into release
This commit is contained in:
@@ -406,6 +406,14 @@ gateway:
|
|||||||
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
- `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For
|
||||||
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
- `turnstile.required` 在 release 模式强制启用 Turnstile
|
||||||
|
|
||||||
|
**网关防御纵深建议(重点)**
|
||||||
|
|
||||||
|
- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。
|
||||||
|
- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。
|
||||||
|
- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。
|
||||||
|
- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。
|
||||||
|
- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。
|
||||||
|
|
||||||
**⚠️ 安全警告:HTTP URL 配置**
|
**⚠️ 安全警告:HTTP URL 配置**
|
||||||
|
|
||||||
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置:
|
||||||
|
|||||||
@@ -308,6 +308,12 @@ type GatewayConfig struct {
|
|||||||
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
|
||||||
// 请求体最大字节数,用于网关请求体大小限制
|
// 请求体最大字节数,用于网关请求体大小限制
|
||||||
MaxBodySize int64 `mapstructure:"max_body_size"`
|
MaxBodySize int64 `mapstructure:"max_body_size"`
|
||||||
|
// 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大
|
||||||
|
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"`
|
||||||
|
// 代理探测响应体读取上限(字节)
|
||||||
|
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"`
|
||||||
|
// Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销)
|
||||||
|
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
|
||||||
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
// ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy)
|
||||||
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"`
|
||||||
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
|
||||||
@@ -1059,6 +1065,9 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
|
||||||
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
|
||||||
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
|
||||||
|
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
|
||||||
|
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
|
||||||
|
viper.SetDefault("gateway.gemini_debug_response_headers", false)
|
||||||
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
|
||||||
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
|
||||||
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
|
||||||
@@ -1465,6 +1474,12 @@ func (c *Config) Validate() error {
|
|||||||
if c.Gateway.MaxBodySize <= 0 {
|
if c.Gateway.MaxBodySize <= 0 {
|
||||||
return fmt.Errorf("gateway.max_body_size must be positive")
|
return fmt.Errorf("gateway.max_body_size must be positive")
|
||||||
}
|
}
|
||||||
|
if c.Gateway.UpstreamResponseReadMaxBytes <= 0 {
|
||||||
|
return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
|
||||||
|
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
|
||||||
|
}
|
||||||
if c.Gateway.SoraMaxBodySize < 0 {
|
if c.Gateway.SoraMaxBodySize < 0 {
|
||||||
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1106,7 +1106,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.AccountFromService(account))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||||
|
|||||||
@@ -418,8 +418,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gateway.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -683,8 +687,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 错误响应已在Forward中处理,这里只记录日志
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("gateway.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("gateway.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1117,6 +1125,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
h.errorResponse(c, status, errType, message)
|
h.errorResponse(c, status, errType, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||||
|
func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||||
|
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// errorResponse 返回Claude API格式的错误响应
|
// errorResponse 返回Claude API格式的错误响应
|
||||||
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.True(t, wrote)
|
||||||
|
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "error", parsed["type"])
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &GatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.False(t, wrote)
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
}
|
||||||
@@ -365,8 +365,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Error response already handled in Forward, just log
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
reqLog.Error("openai.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
reqLog.Error("openai.forward_failed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -521,6 +525,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
|||||||
h.errorResponse(c, status, errType, message)
|
h.errorResponse(c, status, errType, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||||||
|
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||||||
|
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// errorResponse returns OpenAI API format error response
|
// errorResponse returns OpenAI API format error response
|
||||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||||
c.JSON(status, gin.H{
|
c.JSON(status, gin.H{
|
||||||
|
|||||||
@@ -105,6 +105,42 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
|||||||
assert.Equal(t, "test error", errorObj["message"])
|
assert.Equal(t, "test error", errorObj["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.True(t, wrote)
|
||||||
|
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.String(http.StatusTeapot, "already written")
|
||||||
|
|
||||||
|
h := &OpenAIGatewayHandler{}
|
||||||
|
wrote := h.ensureForwardErrorResponse(c, false)
|
||||||
|
|
||||||
|
require.False(t, wrote)
|
||||||
|
require.Equal(t, http.StatusTeapot, w.Code)
|
||||||
|
assert.Equal(t, "already written", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
|||||||
return normalizeIP(c.ClientIP())
|
return normalizeIP(c.ClientIP())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||||
|
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||||
|
// 适用于 ACL / 风控等安全敏感场景。
|
||||||
|
func GetTrustedClientIP(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return normalizeIP(c.ClientIP())
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||||
func normalizeIP(ip string) string {
|
func normalizeIP(ip string) string {
|
||||||
ip = strings.TrimSpace(ip)
|
ip = strings.TrimSpace(ip)
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
package ip
|
package ip
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,3 +51,25 @@ func TestIsPrivateIP(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
require.NoError(t, r.SetTrustedProxies(nil))
|
||||||
|
|
||||||
|
r.GET("/t", func(c *gin.Context) {
|
||||||
|
c.String(200, GetTrustedClientIP(c))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/t", nil)
|
||||||
|
req.RemoteAddr = "9.9.9.9:12345"
|
||||||
|
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||||
|
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||||
|
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, 200, w.Code)
|
||||||
|
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||||
|
}
|
||||||
|
|||||||
@@ -1194,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
|||||||
}
|
}
|
||||||
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||||
if phaseFilter != "upstream" {
|
if phaseFilter != "upstream" {
|
||||||
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||||
@@ -1208,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
|||||||
}
|
}
|
||||||
if p := strings.TrimSpace(filter.Platform); p != "" {
|
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||||
args = append(args, p)
|
args = append(args, p)
|
||||||
clauses = append(clauses, "platform = $"+itoa(len(args)))
|
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||||
args = append(args, *filter.GroupID)
|
args = append(args, *filter.GroupID)
|
||||||
clauses = append(clauses, "group_id = $"+itoa(len(args)))
|
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||||
args = append(args, *filter.AccountID)
|
args = append(args, *filter.AccountID)
|
||||||
clauses = append(clauses, "account_id = $"+itoa(len(args)))
|
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if phase := phaseFilter; phase != "" {
|
if phase := phaseFilter; phase != "" {
|
||||||
args = append(args, phase)
|
args = append(args, phase)
|
||||||
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if filter != nil {
|
if filter != nil {
|
||||||
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
||||||
args = append(args, owner)
|
args = append(args, owner)
|
||||||
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
|
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
||||||
args = append(args, source)
|
args = append(args, source)
|
||||||
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
|
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resolvedFilter != nil {
|
if resolvedFilter != nil {
|
||||||
args = append(args, *resolvedFilter)
|
args = append(args, *resolvedFilter)
|
||||||
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
|
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// View filter: errors vs excluded vs all.
|
// View filter: errors vs excluded vs all.
|
||||||
@@ -1246,46 +1246,46 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
|||||||
}
|
}
|
||||||
switch view {
|
switch view {
|
||||||
case "", "errors":
|
case "", "errors":
|
||||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||||
case "excluded":
|
case "excluded":
|
||||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
|
||||||
case "all":
|
case "all":
|
||||||
// no-op
|
// no-op
|
||||||
default:
|
default:
|
||||||
// treat unknown as default 'errors'
|
// treat unknown as default 'errors'
|
||||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||||
}
|
}
|
||||||
if len(filter.StatusCodes) > 0 {
|
if len(filter.StatusCodes) > 0 {
|
||||||
args = append(args, pq.Array(filter.StatusCodes))
|
args = append(args, pq.Array(filter.StatusCodes))
|
||||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||||
} else if filter.StatusCodesOther {
|
} else if filter.StatusCodesOther {
|
||||||
// "Other" means: status codes not in the common list.
|
// "Other" means: status codes not in the common list.
|
||||||
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
||||||
args = append(args, pq.Array(known))
|
args = append(args, pq.Array(known))
|
||||||
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
|
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||||
}
|
}
|
||||||
// Exact correlation keys (preferred for request↔upstream linkage).
|
// Exact correlation keys (preferred for request↔upstream linkage).
|
||||||
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
||||||
args = append(args, rid)
|
args = append(args, rid)
|
||||||
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
|
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
||||||
args = append(args, crid)
|
args = append(args, crid)
|
||||||
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
|
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
like := "%" + q + "%"
|
like := "%" + q + "%"
|
||||||
args = append(args, like)
|
args = append(args, like)
|
||||||
n := itoa(len(args))
|
n := itoa(len(args))
|
||||||
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
|
||||||
}
|
}
|
||||||
|
|
||||||
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
||||||
like := "%" + userQuery + "%"
|
like := "%" + userQuery + "%"
|
||||||
args = append(args, like)
|
args = append(args, like)
|
||||||
n := itoa(len(args))
|
n := itoa(len(args))
|
||||||
clauses = append(clauses, "u.email ILIKE $"+n)
|
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
|
||||||
}
|
}
|
||||||
|
|
||||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
|||||||
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
|
||||||
|
filter := &service.OpsErrorLogFilter{
|
||||||
|
Query: "ACCESS_DENIED",
|
||||||
|
}
|
||||||
|
|
||||||
|
where, args := buildOpsErrorLogsWhere(filter)
|
||||||
|
if where == "" {
|
||||||
|
t.Fatalf("where should not be empty")
|
||||||
|
}
|
||||||
|
if len(args) != 1 {
|
||||||
|
t.Fatalf("args len = %d, want 1", len(args))
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "e.request_id ILIKE $") {
|
||||||
|
t.Fatalf("where should include qualified request_id condition: %s", where)
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "e.client_request_id ILIKE $") {
|
||||||
|
t.Fatalf("where should include qualified client_request_id condition: %s", where)
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "e.error_message ILIKE $") {
|
||||||
|
t.Fatalf("where should include qualified error_message condition: %s", where)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
|
||||||
|
filter := &service.OpsErrorLogFilter{
|
||||||
|
UserQuery: "admin@",
|
||||||
|
}
|
||||||
|
|
||||||
|
where, args := buildOpsErrorLogsWhere(filter)
|
||||||
|
if where == "" {
|
||||||
|
t.Fatalf("where should not be empty")
|
||||||
|
}
|
||||||
|
if len(args) != 1 {
|
||||||
|
t.Fatalf("args len = %d, want 1", len(args))
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
|
||||||
|
t.Fatalf("where should include EXISTS user email condition: %s", where)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
insecure := false
|
insecure := false
|
||||||
allowPrivate := false
|
allowPrivate := false
|
||||||
validateResolvedIP := true
|
validateResolvedIP := true
|
||||||
|
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||||
|
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||||
|
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if insecure {
|
if insecure {
|
||||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
|||||||
insecureSkipVerify: insecure,
|
insecureSkipVerify: insecure,
|
||||||
allowPrivateHosts: allowPrivate,
|
allowPrivateHosts: allowPrivate,
|
||||||
validateResolvedIP: validateResolvedIP,
|
validateResolvedIP: validateResolvedIP,
|
||||||
|
maxResponseBytes: maxResponseBytes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultProxyProbeTimeout = 30 * time.Second
|
defaultProxyProbeTimeout = 30 * time.Second
|
||||||
|
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||||
)
|
)
|
||||||
|
|
||||||
// probeURLs 按优先级排列的探测 URL 列表
|
// probeURLs 按优先级排列的探测 URL 列表
|
||||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
|||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
allowPrivateHosts bool
|
allowPrivateHosts bool
|
||||||
validateResolvedIP bool
|
validateResolvedIP bool
|
||||||
|
maxResponseBytes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
|||||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxResponseBytes := s.maxResponseBytes
|
||||||
|
if maxResponseBytes <= 0 {
|
||||||
|
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||||
}
|
}
|
||||||
|
if int64(len(body)) > maxResponseBytes {
|
||||||
|
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||||
|
}
|
||||||
|
|
||||||
switch parser {
|
switch parser {
|
||||||
case "ip-api":
|
case "ip-api":
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
|||||||
if err := r.SetTrustedProxies(nil); err != nil {
|
if err := r.SetTrustedProxies(nil); err != nil {
|
||||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||||
}
|
}
|
||||||
|
if cfg.Server.Mode == "release" {
|
||||||
|
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
|||||||
// 检查 IP 限制(白名单/黑名单)
|
// 检查 IP 限制(白名单/黑名单)
|
||||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||||
clientIP := ip.GetClientIP(c)
|
clientIP := ip.GetTrustedClientIP(c)
|
||||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||||
|
|||||||
@@ -300,6 +300,57 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, w.Code)
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
user := &service.User{
|
||||||
|
ID: 7,
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Balance: 10,
|
||||||
|
Concurrency: 3,
|
||||||
|
}
|
||||||
|
apiKey := &service.APIKey{
|
||||||
|
ID: 100,
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "test-key",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
User: user,
|
||||||
|
IPWhitelist: []string{"1.2.3.4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKeyRepo := &stubApiKeyRepo{
|
||||||
|
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||||
|
if key != apiKey.Key {
|
||||||
|
return nil, service.ErrAPIKeyNotFound
|
||||||
|
}
|
||||||
|
clone := *apiKey
|
||||||
|
return &clone, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||||
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||||
|
router := gin.New()
|
||||||
|
require.NoError(t, router.SetTrustedProxies(nil))
|
||||||
|
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||||
|
router.GET("/t", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||||
|
req.RemoteAddr = "9.9.9.9:12345"
|
||||||
|
req.Header.Set("x-api-key", apiKey.Key)
|
||||||
|
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||||
|
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||||
|
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||||
|
}
|
||||||
|
|
||||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||||
|
|||||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
|||||||
// 公开接口
|
// 公开接口
|
||||||
auth := v1.Group("/auth")
|
auth := v1.Group("/auth")
|
||||||
{
|
{
|
||||||
auth.POST("/register", h.Auth.Register)
|
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||||
auth.POST("/login", h.Auth.Login)
|
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
}), h.Auth.Register)
|
||||||
|
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.Login)
|
||||||
|
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.Login2FA)
|
||||||
|
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||||
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
}), h.Auth.SendVerifyCode)
|
||||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||||
FailureMode: middleware.RateLimitFailClose,
|
FailureMode: middleware.RateLimitFailClose,
|
||||||
|
|||||||
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||||
|
|
||||||
|
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
rdb := startAuthRouteRedis(t, ctx)
|
||||||
|
|
||||||
|
router := newAuthRoutesTestRouter(rdb)
|
||||||
|
const path = "/api/v1/auth/register"
|
||||||
|
|
||||||
|
for i := 1; i <= 6; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.RemoteAddr = "198.51.100.10:23456"
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if i <= 5 {
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||||
|
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||||
|
t.Helper()
|
||||||
|
ensureAuthRouteDockerAvailable(t)
|
||||||
|
|
||||||
|
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = redisContainer.Terminate(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
redisHost, err := redisContainer.Host(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
require.NoError(t, rdb.Ping(ctx).Err())
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
return rdb
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if authRouteDockerAvailable() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||||
|
}
|
||||||
|
|
||||||
|
func authRouteDockerAvailable() bool {
|
||||||
|
if os.Getenv("DOCKER_HOST") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
socketCandidates := []string{
|
||||||
|
"/var/run/docker.sock",
|
||||||
|
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||||
|
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||||
|
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||||
|
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, socket := range socketCandidates {
|
||||||
|
if socket == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(socket); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func authRouteUserHomeDir() string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return home
|
||||||
|
}
|
||||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package routes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
v1 := router.Group("/api/v1")
|
||||||
|
|
||||||
|
RegisterAuthRoutes(
|
||||||
|
v1,
|
||||||
|
&handler.Handlers{
|
||||||
|
Auth: &handler.AuthHandler{},
|
||||||
|
Setting: &handler.SettingHandler{},
|
||||||
|
},
|
||||||
|
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
}),
|
||||||
|
redisClient,
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: "127.0.0.1:1",
|
||||||
|
DialTimeout: 50 * time.Millisecond,
|
||||||
|
ReadTimeout: 50 * time.Millisecond,
|
||||||
|
WriteTimeout: 50 * time.Millisecond,
|
||||||
|
})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
router := newAuthRoutesTestRouter(rdb)
|
||||||
|
paths := []string{
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
"/api/v1/auth/login/2fa",
|
||||||
|
"/api/v1/auth/send-verify-code",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range paths {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.RemoteAddr = "203.0.113.10:12345"
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||||
|
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
|
|
||||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
resp = retryResp
|
resp = retryResp
|
||||||
respBody, err = io.ReadAll(resp.Body)
|
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||||
|
return err
|
||||||
|
}
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||||
// Log response headers for debugging
|
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var parsed map[string]any
|
|
||||||
if isOAuth {
|
if isOAuth {
|
||||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||||
if uwErr == nil {
|
if uwErr == nil {
|
||||||
respBody = unwrappedBody
|
respBody = unwrappedBody
|
||||||
}
|
}
|
||||||
_ = json.Unmarshal(respBody, &parsed)
|
|
||||||
} else {
|
|
||||||
_ = json.Unmarshal(respBody, &parsed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||||
// Log response headers for debugging
|
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||||
}
|
}
|
||||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
|
||||||
|
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
|
|||||||
@@ -3,10 +3,15 @@ package service
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
svc := &GeminiMessagesCompatService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
GeminiDebugResponseHeaders: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/json"},
|
||||||
|
"X-RateLimit-Limit": []string{"60"},
|
||||||
|
},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||||
claudeReq := map[string]any{
|
claudeReq := map[string]any{
|
||||||
"model": "claude-haiku-4-5-20251001",
|
"model": "claude-haiku-4-5-20251001",
|
||||||
|
|||||||
@@ -313,7 +313,6 @@ func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Acco
|
|||||||
}
|
}
|
||||||
log := logger.FromContext(ctx).With(fields...)
|
log := logger.FromContext(ctx).With(fields...)
|
||||||
if result.Matched {
|
if result.Matched {
|
||||||
log.Warn("OpenAI codex_cli_only 允许官方客户端请求")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
||||||
@@ -1277,6 +1276,29 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
startTime time.Time,
|
startTime time.Time,
|
||||||
) (*OpenAIForwardResult, error) {
|
) (*OpenAIForwardResult, error) {
|
||||||
if account != nil && account.Type == AccountTypeOAuth {
|
if account != nil && account.Type == AccountTypeOAuth {
|
||||||
|
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
|
||||||
|
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
|
||||||
|
setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "")
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: http.StatusForbidden,
|
||||||
|
Passthrough: true,
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: rejectMsg,
|
||||||
|
Detail: rejectReason,
|
||||||
|
})
|
||||||
|
logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body)
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "forbidden_error",
|
||||||
|
"message": rejectMsg,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason)
|
||||||
|
}
|
||||||
|
|
||||||
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
|
normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1396,6 +1418,37 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func logOpenAIPassthroughInstructionsRejected(
|
||||||
|
ctx context.Context,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
reqModel string,
|
||||||
|
rejectReason string,
|
||||||
|
body []byte,
|
||||||
|
) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
accountID := int64(0)
|
||||||
|
accountName := ""
|
||||||
|
accountType := ""
|
||||||
|
if account != nil {
|
||||||
|
accountID = account.ID
|
||||||
|
accountName = strings.TrimSpace(account.Name)
|
||||||
|
accountType = strings.TrimSpace(string(account.Type))
|
||||||
|
}
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.String("component", "service.openai_gateway"),
|
||||||
|
zap.Int64("account_id", accountID),
|
||||||
|
zap.String("account_name", accountName),
|
||||||
|
zap.String("account_type", accountType),
|
||||||
|
zap.String("request_model", strings.TrimSpace(reqModel)),
|
||||||
|
zap.String("reject_reason", strings.TrimSpace(rejectReason)),
|
||||||
|
}
|
||||||
|
fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
|
||||||
|
logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
@@ -1688,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
|||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
) (*OpenAIUsage, error) {
|
) (*OpenAIUsage, error) {
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2318,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
body, err := io.ReadAll(resp.Body)
|
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||||
|
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||||
|
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": "Upstream response too large",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2877,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
|||||||
return normalized, changed, nil
|
return normalized, changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||||
|
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||||
|
if !strings.Contains(model, "codex") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions := gjson.GetBytes(body, "instructions")
|
||||||
|
if !instructions.Exists() {
|
||||||
|
return "instructions_missing"
|
||||||
|
}
|
||||||
|
if instructions.Type != gjson.String {
|
||||||
|
return "instructions_not_string"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(instructions.String()) == "" {
|
||||||
|
return "instructions_empty"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||||
if reasoningEffort == "" {
|
if reasoningEffort == "" {
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) {
|
||||||
logSink, restore := captureStructuredLog(t)
|
logSink, restore := captureStructuredLog(t)
|
||||||
defer restore()
|
defer restore()
|
||||||
|
|
||||||
@@ -119,7 +119,7 @@ func TestLogCodexCLIOnlyDetection_LogsBothMatchedAndRejected(t *testing.T) {
|
|||||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求"))
|
||||||
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
|
require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求"))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(rec)
|
c, _ := gin.CreateTestContext(rec)
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
|
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
|||||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||||
}, body)
|
}, body)
|
||||||
|
|
||||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0"))
|
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
|
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
|
||||||
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
|
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
|
||||||
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
|
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
|||||||
c.Request.Header.Set("Proxy-Authorization", "Basic abc")
|
c.Request.Header.Set("Proxy-Authorization", "Basic abc")
|
||||||
c.Request.Header.Set("X-Test", "keep")
|
c.Request.Header.Set("X-Test", "keep")
|
||||||
|
|
||||||
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
|
||||||
upstreamSSE := strings.Join([]string{
|
upstreamSSE := strings.Join([]string{
|
||||||
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
|
`data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`,
|
||||||
@@ -211,6 +211,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
|||||||
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
|
// 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。
|
||||||
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
|
||||||
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||||
|
require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String()))
|
||||||
// 其余关键字段保持原值。
|
// 其余关键字段保持原值。
|
||||||
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
|
require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||||
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
|
require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String())
|
||||||
@@ -235,6 +236,59 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormali
|
|||||||
require.NotContains(t, body, "\"name\":\"edit\"")
|
require.NotContains(t, body, "\"name\":\"edit\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
logSink, restore := captureStructuredLog(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||||
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||||
|
|
||||||
|
// Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
|
||||||
|
upstream := &httpUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||||
|
Extra: map[string]any{"openai_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
require.Contains(t, rec.Body.String(), "requires a non-empty instructions field")
|
||||||
|
require.Nil(t, upstream.lastReq)
|
||||||
|
|
||||||
|
require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||||
|
require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||||
|
|
||||||
|
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||||
|
|
||||||
|
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||||
|
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||||
|
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
return defaultUpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||||
|
if reader == nil {
|
||||||
|
return nil, errors.New("response body is nil")
|
||||||
|
}
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(len(body)) > maxBytes {
|
||||||
|
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||||
|
t.Run("use default when config missing", func(t *testing.T) {
|
||||||
|
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("use configured value", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||||
|
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||||
|
t.Run("within limit", func(t *testing.T) {
|
||||||
|
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte("ok"), body)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exceeds limit", func(t *testing.T) {
|
||||||
|
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||||
|
require.Nil(t, body)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -146,6 +146,15 @@ gateway:
|
|||||||
# Max request body size in bytes (default: 100MB)
|
# Max request body size in bytes (default: 100MB)
|
||||||
# 请求体最大字节数(默认 100MB)
|
# 请求体最大字节数(默认 100MB)
|
||||||
max_body_size: 104857600
|
max_body_size: 104857600
|
||||||
|
# Max bytes to read for non-stream upstream responses (default: 8MB)
|
||||||
|
# 非流式上游响应体读取上限(默认 8MB)
|
||||||
|
upstream_response_read_max_bytes: 8388608
|
||||||
|
# Max bytes to read for proxy probe responses (default: 1MB)
|
||||||
|
# 代理探测响应体读取上限(默认 1MB)
|
||||||
|
proxy_probe_response_read_max_bytes: 1048576
|
||||||
|
# Enable Gemini upstream response header debug logs (default: false)
|
||||||
|
# 是否开启 Gemini 上游响应头调试日志(默认 false)
|
||||||
|
gemini_debug_response_headers: false
|
||||||
# Sora max request body size in bytes (0=use max_body_size)
|
# Sora max request body size in bytes (0=use max_body_size)
|
||||||
# Sora 请求体最大字节数(0=使用 max_body_size)
|
# Sora 请求体最大字节数(0=使用 max_body_size)
|
||||||
sora_max_body_size: 268435456
|
sora_max_body_size: 268435456
|
||||||
|
|||||||
@@ -39,16 +39,6 @@ watch(
|
|||||||
{ immediate: true }
|
{ immediate: true }
|
||||||
)
|
)
|
||||||
|
|
||||||
watch(
|
|
||||||
() => appStore.siteName,
|
|
||||||
(newName) => {
|
|
||||||
if (newName) {
|
|
||||||
document.title = `${newName} - AI API Gateway`
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{ immediate: true }
|
|
||||||
)
|
|
||||||
|
|
||||||
// Watch for authentication state and manage subscription data
|
// Watch for authentication state and manage subscription data
|
||||||
watch(
|
watch(
|
||||||
() => authStore.isAuthenticated,
|
() => authStore.isAuthenticated,
|
||||||
|
|||||||
@@ -58,12 +58,16 @@ describe('ImportDataModal', () => {
|
|||||||
|
|
||||||
const input = wrapper.find('input[type="file"]')
|
const input = wrapper.find('input[type="file"]')
|
||||||
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
||||||
|
Object.defineProperty(file, 'text', {
|
||||||
|
value: () => Promise.resolve('invalid json')
|
||||||
|
})
|
||||||
Object.defineProperty(input.element, 'files', {
|
Object.defineProperty(input.element, 'files', {
|
||||||
value: [file]
|
value: [file]
|
||||||
})
|
})
|
||||||
|
|
||||||
await input.trigger('change')
|
await input.trigger('change')
|
||||||
await wrapper.find('form').trigger('submit')
|
await wrapper.find('form').trigger('submit')
|
||||||
|
await Promise.resolve()
|
||||||
|
|
||||||
expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed')
|
expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed')
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -58,12 +58,16 @@ describe('Proxy ImportDataModal', () => {
|
|||||||
|
|
||||||
const input = wrapper.find('input[type="file"]')
|
const input = wrapper.find('input[type="file"]')
|
||||||
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
const file = new File(['invalid json'], 'data.json', { type: 'application/json' })
|
||||||
|
Object.defineProperty(file, 'text', {
|
||||||
|
value: () => Promise.resolve('invalid json')
|
||||||
|
})
|
||||||
Object.defineProperty(input.element, 'files', {
|
Object.defineProperty(input.element, 'files', {
|
||||||
value: [file]
|
value: [file]
|
||||||
})
|
})
|
||||||
|
|
||||||
await input.trigger('change')
|
await input.trigger('change')
|
||||||
await wrapper.find('form').trigger('submit')
|
await wrapper.find('form').trigger('submit')
|
||||||
|
await Promise.resolve()
|
||||||
|
|
||||||
expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed')
|
expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed')
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -164,10 +164,10 @@ export async function getUsage(id: number): Promise<AccountUsageInfo> {
|
|||||||
/**
|
/**
|
||||||
* Clear account rate limit status
|
* Clear account rate limit status
|
||||||
* @param id - Account ID
|
* @param id - Account ID
|
||||||
* @returns Success confirmation
|
* @returns Updated account
|
||||||
*/
|
*/
|
||||||
export async function clearRateLimit(id: number): Promise<{ message: string }> {
|
export async function clearRateLimit(id: number): Promise<Account> {
|
||||||
const { data } = await apiClient.post<{ message: string }>(
|
const { data } = await apiClient.post<Account>(
|
||||||
`/admin/accounts/${id}/clear-rate-limit`
|
`/admin/accounts/${id}/clear-rate-limit`
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -209,7 +209,7 @@
|
|||||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in modelMappings"
|
v-for="(mapping, index) in modelMappings"
|
||||||
:key="index"
|
:key="getModelMappingKey(mapping)"
|
||||||
class="flex items-center gap-2"
|
class="flex items-center gap-2"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
@@ -654,6 +654,7 @@ import Select from '@/components/common/Select.vue'
|
|||||||
import ProxySelector from '@/components/common/ProxySelector.vue'
|
import ProxySelector from '@/components/common/ProxySelector.vue'
|
||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
show: boolean
|
show: boolean
|
||||||
@@ -695,6 +696,7 @@ const baseUrl = ref('')
|
|||||||
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
|
||||||
const allowedModels = ref<string[]>([])
|
const allowedModels = ref<string[]>([])
|
||||||
const modelMappings = ref<ModelMapping[]>([])
|
const modelMappings = ref<ModelMapping[]>([])
|
||||||
|
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('bulk-model-mapping')
|
||||||
const selectedErrorCodes = ref<number[]>([])
|
const selectedErrorCodes = ref<number[]>([])
|
||||||
const customErrorCodeInput = ref<number | null>(null)
|
const customErrorCodeInput = ref<number | null>(null)
|
||||||
const interceptWarmupRequests = ref(false)
|
const interceptWarmupRequests = ref(false)
|
||||||
|
|||||||
@@ -714,7 +714,7 @@
|
|||||||
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in antigravityModelMappings"
|
v-for="(mapping, index) in antigravityModelMappings"
|
||||||
:key="index"
|
:key="getAntigravityModelMappingKey(mapping)"
|
||||||
class="space-y-1"
|
class="space-y-1"
|
||||||
>
|
>
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
@@ -966,7 +966,7 @@
|
|||||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in modelMappings"
|
v-for="(mapping, index) in modelMappings"
|
||||||
:key="index"
|
:key="getModelMappingKey(mapping)"
|
||||||
class="flex items-center gap-2"
|
class="flex items-center gap-2"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
@@ -1225,7 +1225,7 @@
|
|||||||
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
||||||
<div
|
<div
|
||||||
v-for="(rule, index) in tempUnschedRules"
|
v-for="(rule, index) in tempUnschedRules"
|
||||||
:key="index"
|
:key="getTempUnschedRuleKey(rule)"
|
||||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="mb-2 flex items-center justify-between">
|
<div class="mb-2 flex items-center justify-between">
|
||||||
@@ -2097,6 +2097,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
|
|||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue'
|
||||||
|
|
||||||
// Type for exposed OAuthAuthorizationFlow component
|
// Type for exposed OAuthAuthorizationFlow component
|
||||||
@@ -2227,6 +2228,9 @@ const antigravityModelMappings = ref<ModelMapping[]>([])
|
|||||||
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity'))
|
||||||
const tempUnschedEnabled = ref(false)
|
const tempUnschedEnabled = ref(false)
|
||||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||||
|
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-model-mapping')
|
||||||
|
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('create-antigravity-model-mapping')
|
||||||
|
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('create-temp-unsched-rule')
|
||||||
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
|
||||||
const geminiAIStudioOAuthEnabled = ref(false)
|
const geminiAIStudioOAuthEnabled = ref(false)
|
||||||
|
|
||||||
|
|||||||
@@ -169,7 +169,7 @@
|
|||||||
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="modelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in modelMappings"
|
v-for="(mapping, index) in modelMappings"
|
||||||
:key="index"
|
:key="getModelMappingKey(mapping)"
|
||||||
class="flex items-center gap-2"
|
class="flex items-center gap-2"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
@@ -417,7 +417,7 @@
|
|||||||
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
<div v-if="antigravityModelMappings.length > 0" class="mb-3 space-y-2">
|
||||||
<div
|
<div
|
||||||
v-for="(mapping, index) in antigravityModelMappings"
|
v-for="(mapping, index) in antigravityModelMappings"
|
||||||
:key="index"
|
:key="getAntigravityModelMappingKey(mapping)"
|
||||||
class="space-y-1"
|
class="space-y-1"
|
||||||
>
|
>
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
@@ -542,7 +542,7 @@
|
|||||||
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
<div v-if="tempUnschedRules.length > 0" class="space-y-3">
|
||||||
<div
|
<div
|
||||||
v-for="(rule, index) in tempUnschedRules"
|
v-for="(rule, index) in tempUnschedRules"
|
||||||
:key="index"
|
:key="getTempUnschedRuleKey(rule)"
|
||||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="mb-2 flex items-center justify-between">
|
<div class="mb-2 flex items-center justify-between">
|
||||||
@@ -1093,6 +1093,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue'
|
|||||||
import GroupSelector from '@/components/common/GroupSelector.vue'
|
import GroupSelector from '@/components/common/GroupSelector.vue'
|
||||||
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
|
||||||
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
import {
|
import {
|
||||||
getPresetMappingsByPlatform,
|
getPresetMappingsByPlatform,
|
||||||
commonErrorCodes,
|
commonErrorCodes,
|
||||||
@@ -1110,7 +1111,7 @@ interface Props {
|
|||||||
const props = defineProps<Props>()
|
const props = defineProps<Props>()
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
close: []
|
close: []
|
||||||
updated: []
|
updated: [account: Account]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
@@ -1158,6 +1159,9 @@ const antigravityWhitelistModels = ref<string[]>([])
|
|||||||
const antigravityModelMappings = ref<ModelMapping[]>([])
|
const antigravityModelMappings = ref<ModelMapping[]>([])
|
||||||
const tempUnschedEnabled = ref(false)
|
const tempUnschedEnabled = ref(false)
|
||||||
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
const tempUnschedRules = ref<TempUnschedRuleForm[]>([])
|
||||||
|
const getModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-model-mapping')
|
||||||
|
const getAntigravityModelMappingKey = createStableObjectKeyResolver<ModelMapping>('edit-antigravity-model-mapping')
|
||||||
|
const getTempUnschedRuleKey = createStableObjectKeyResolver<TempUnschedRuleForm>('edit-temp-unsched-rule')
|
||||||
|
|
||||||
// Mixed channel warning dialog state
|
// Mixed channel warning dialog state
|
||||||
const showMixedChannelWarning = ref(false)
|
const showMixedChannelWarning = ref(false)
|
||||||
@@ -1845,9 +1849,9 @@ const handleSubmit = async () => {
|
|||||||
updatePayload.extra = newExtra
|
updatePayload.extra = newExtra
|
||||||
}
|
}
|
||||||
|
|
||||||
await adminAPI.accounts.update(props.account.id, updatePayload)
|
const updatedAccount = await adminAPI.accounts.update(props.account.id, updatePayload)
|
||||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||||
emit('updated')
|
emit('updated', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
// Handle 409 mixed_channel_warning - show confirmation dialog
|
// Handle 409 mixed_channel_warning - show confirmation dialog
|
||||||
@@ -1875,9 +1879,9 @@ const handleMixedChannelConfirm = async () => {
|
|||||||
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
|
pendingUpdatePayload.value.confirm_mixed_channel_risk = true
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
const updatedAccount = await adminAPI.accounts.update(props.account.id, pendingUpdatePayload.value)
|
||||||
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
appStore.showSuccess(t('admin.accounts.accountUpdated'))
|
||||||
emit('updated')
|
emit('updated', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate'))
|
||||||
|
|||||||
@@ -143,6 +143,24 @@ const handleClose = () => {
|
|||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const readFileAsText = async (sourceFile: File): Promise<string> => {
|
||||||
|
if (typeof sourceFile.text === 'function') {
|
||||||
|
return sourceFile.text()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof sourceFile.arrayBuffer === 'function') {
|
||||||
|
const buffer = await sourceFile.arrayBuffer()
|
||||||
|
return new TextDecoder().decode(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return await new Promise<string>((resolve, reject) => {
|
||||||
|
const reader = new FileReader()
|
||||||
|
reader.onload = () => resolve(String(reader.result ?? ''))
|
||||||
|
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
|
||||||
|
reader.readAsText(sourceFile)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const handleImport = async () => {
|
const handleImport = async () => {
|
||||||
if (!file.value) {
|
if (!file.value) {
|
||||||
appStore.showError(t('admin.accounts.dataImportSelectFile'))
|
appStore.showError(t('admin.accounts.dataImportSelectFile'))
|
||||||
@@ -151,7 +169,7 @@ const handleImport = async () => {
|
|||||||
|
|
||||||
importing.value = true
|
importing.value = true
|
||||||
try {
|
try {
|
||||||
const text = await file.value.text()
|
const text = await readFileAsText(file.value)
|
||||||
const dataPayload = JSON.parse(text)
|
const dataPayload = JSON.parse(text)
|
||||||
|
|
||||||
const res = await adminAPI.accounts.importData({
|
const res = await adminAPI.accounts.importData({
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ interface Props {
|
|||||||
const props = defineProps<Props>()
|
const props = defineProps<Props>()
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
close: []
|
close: []
|
||||||
reauthorized: []
|
reauthorized: [account: Account]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
@@ -370,10 +370,10 @@ const handleExchangeCode = async () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Clear error status after successful re-authorization
|
// Clear error status after successful re-authorization
|
||||||
await adminAPI.accounts.clearError(props.account.id)
|
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||||
|
|
||||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||||
emit('reauthorized')
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
@@ -404,9 +404,9 @@ const handleExchangeCode = async () => {
|
|||||||
type: 'oauth',
|
type: 'oauth',
|
||||||
credentials
|
credentials
|
||||||
})
|
})
|
||||||
await adminAPI.accounts.clearError(props.account.id)
|
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||||
emit('reauthorized')
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
geminiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
@@ -436,9 +436,9 @@ const handleExchangeCode = async () => {
|
|||||||
type: 'oauth',
|
type: 'oauth',
|
||||||
credentials
|
credentials
|
||||||
})
|
})
|
||||||
await adminAPI.accounts.clearError(props.account.id)
|
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||||
emit('reauthorized')
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
antigravityOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
@@ -475,10 +475,10 @@ const handleExchangeCode = async () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Clear error status after successful re-authorization
|
// Clear error status after successful re-authorization
|
||||||
await adminAPI.accounts.clearError(props.account.id)
|
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||||
|
|
||||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||||
emit('reauthorized')
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
claudeOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed')
|
||||||
@@ -518,10 +518,10 @@ const handleCookieAuth = async (sessionKey: string) => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Clear error status after successful re-authorization
|
// Clear error status after successful re-authorization
|
||||||
await adminAPI.accounts.clearError(props.account.id)
|
const updatedAccount = await adminAPI.accounts.clearError(props.account.id)
|
||||||
|
|
||||||
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
appStore.showSuccess(t('admin.accounts.reAuthorizedSuccess'))
|
||||||
emit('reauthorized')
|
emit('reauthorized', updatedAccount)
|
||||||
handleClose()
|
handleClose()
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
claudeOAuth.error.value =
|
claudeOAuth.error.value =
|
||||||
|
|||||||
@@ -143,6 +143,24 @@ const handleClose = () => {
|
|||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const readFileAsText = async (sourceFile: File): Promise<string> => {
|
||||||
|
if (typeof sourceFile.text === 'function') {
|
||||||
|
return sourceFile.text()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof sourceFile.arrayBuffer === 'function') {
|
||||||
|
const buffer = await sourceFile.arrayBuffer()
|
||||||
|
return new TextDecoder().decode(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return await new Promise<string>((resolve, reject) => {
|
||||||
|
const reader = new FileReader()
|
||||||
|
reader.onload = () => resolve(String(reader.result ?? ''))
|
||||||
|
reader.onerror = () => reject(reader.error || new Error('Failed to read file'))
|
||||||
|
reader.readAsText(sourceFile)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
const handleImport = async () => {
|
const handleImport = async () => {
|
||||||
if (!file.value) {
|
if (!file.value) {
|
||||||
appStore.showError(t('admin.proxies.dataImportSelectFile'))
|
appStore.showError(t('admin.proxies.dataImportSelectFile'))
|
||||||
@@ -151,7 +169,7 @@ const handleImport = async () => {
|
|||||||
|
|
||||||
importing.value = true
|
importing.value = true
|
||||||
try {
|
try {
|
||||||
const text = await file.value.text()
|
const text = await readFileAsText(file.value)
|
||||||
const dataPayload = JSON.parse(text)
|
const dataPayload = JSON.parse(text)
|
||||||
|
|
||||||
const res = await adminAPI.proxies.importData({ data: dataPayload })
|
const res = await adminAPI.proxies.importData({ data: dataPayload })
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
<template v-if="loading">
|
<template v-if="loading">
|
||||||
<div v-for="i in 5" :key="i" class="rounded-lg border border-gray-200 bg-white p-4 dark:border-dark-700 dark:bg-dark-900">
|
<div v-for="i in 5" :key="i" class="rounded-lg border border-gray-200 bg-white p-4 dark:border-dark-700 dark:bg-dark-900">
|
||||||
<div class="space-y-3">
|
<div class="space-y-3">
|
||||||
<div v-for="column in columns.filter(c => c.key !== 'actions')" :key="column.key" class="flex justify-between">
|
<div v-for="column in dataColumns" :key="column.key" class="flex justify-between">
|
||||||
<div class="h-4 w-20 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
<div class="h-4 w-20 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
||||||
<div class="h-4 w-32 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
<div class="h-4 w-32 animate-pulse rounded bg-gray-200 dark:bg-dark-700"></div>
|
||||||
</div>
|
</div>
|
||||||
@@ -39,7 +39,7 @@
|
|||||||
>
|
>
|
||||||
<div class="space-y-3">
|
<div class="space-y-3">
|
||||||
<div
|
<div
|
||||||
v-for="column in columns.filter(c => c.key !== 'actions')"
|
v-for="column in dataColumns"
|
||||||
:key="column.key"
|
:key="column.key"
|
||||||
class="flex items-start justify-between gap-4"
|
class="flex items-start justify-between gap-4"
|
||||||
>
|
>
|
||||||
@@ -439,10 +439,15 @@ const resolveRowKey = (row: any, index: number) => {
|
|||||||
return key ?? index
|
return key ?? index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const dataColumns = computed(() => props.columns.filter((column) => column.key !== 'actions'))
|
||||||
|
const columnsSignature = computed(() =>
|
||||||
|
props.columns.map((column) => `${column.key}:${column.sortable ? '1' : '0'}`).join('|')
|
||||||
|
)
|
||||||
|
|
||||||
// 数据/列变化时重新检查滚动状态
|
// 数据/列变化时重新检查滚动状态
|
||||||
// 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
|
// 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环
|
||||||
watch(
|
watch(
|
||||||
[() => props.data.length, () => props.columns],
|
[() => props.data.length, columnsSignature],
|
||||||
async () => {
|
async () => {
|
||||||
await nextTick()
|
await nextTick()
|
||||||
checkScrollable()
|
checkScrollable()
|
||||||
@@ -555,7 +560,7 @@ onMounted(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
watch(
|
watch(
|
||||||
() => props.columns,
|
columnsSignature,
|
||||||
() => {
|
() => {
|
||||||
// If current sort key is no longer sortable/visible, fall back to default/persisted.
|
// If current sort key is no longer sortable/visible, fall back to default/persisted.
|
||||||
const normalized = normalizeSortKey(sortKey.value)
|
const normalized = normalizeSortKey(sortKey.value)
|
||||||
@@ -575,7 +580,7 @@ watch(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ deep: true }
|
{ flush: 'post' }
|
||||||
)
|
)
|
||||||
|
|
||||||
watch(
|
watch(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
<div class="relative" ref="dropdownRef">
|
<div class="relative" ref="dropdownRef">
|
||||||
<button
|
<button
|
||||||
@click="toggleDropdown"
|
@click="toggleDropdown"
|
||||||
|
:disabled="switching"
|
||||||
class="flex items-center gap-1.5 rounded-lg px-2 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700"
|
class="flex items-center gap-1.5 rounded-lg px-2 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 dark:text-gray-300 dark:hover:bg-dark-700"
|
||||||
:title="currentLocale?.name"
|
:title="currentLocale?.name"
|
||||||
>
|
>
|
||||||
@@ -23,6 +24,7 @@
|
|||||||
<button
|
<button
|
||||||
v-for="locale in availableLocales"
|
v-for="locale in availableLocales"
|
||||||
:key="locale.code"
|
:key="locale.code"
|
||||||
|
:disabled="switching"
|
||||||
@click="selectLocale(locale.code)"
|
@click="selectLocale(locale.code)"
|
||||||
class="flex w-full items-center gap-2 px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-100 dark:text-gray-200 dark:hover:bg-dark-700"
|
class="flex w-full items-center gap-2 px-3 py-2 text-sm text-gray-700 transition-colors hover:bg-gray-100 dark:text-gray-200 dark:hover:bg-dark-700"
|
||||||
:class="{
|
:class="{
|
||||||
@@ -49,6 +51,7 @@ const { locale } = useI18n()
|
|||||||
|
|
||||||
const isOpen = ref(false)
|
const isOpen = ref(false)
|
||||||
const dropdownRef = ref<HTMLElement | null>(null)
|
const dropdownRef = ref<HTMLElement | null>(null)
|
||||||
|
const switching = ref(false)
|
||||||
|
|
||||||
const currentLocaleCode = computed(() => locale.value)
|
const currentLocaleCode = computed(() => locale.value)
|
||||||
const currentLocale = computed(() => availableLocales.find((l) => l.code === locale.value))
|
const currentLocale = computed(() => availableLocales.find((l) => l.code === locale.value))
|
||||||
@@ -57,9 +60,18 @@ function toggleDropdown() {
|
|||||||
isOpen.value = !isOpen.value
|
isOpen.value = !isOpen.value
|
||||||
}
|
}
|
||||||
|
|
||||||
function selectLocale(code: string) {
|
async function selectLocale(code: string) {
|
||||||
setLocale(code)
|
if (switching.value || code === currentLocaleCode.value) {
|
||||||
isOpen.value = false
|
isOpen.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switching.value = true
|
||||||
|
try {
|
||||||
|
await setLocale(code)
|
||||||
|
isOpen.value = false
|
||||||
|
} finally {
|
||||||
|
switching.value = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleClickOutside(event: MouseEvent) {
|
function handleClickOutside(event: MouseEvent) {
|
||||||
|
|||||||
@@ -84,8 +84,8 @@
|
|||||||
|
|
||||||
<!-- Page numbers -->
|
<!-- Page numbers -->
|
||||||
<button
|
<button
|
||||||
v-for="pageNum in visiblePages"
|
v-for="(pageNum, index) in visiblePages"
|
||||||
:key="pageNum"
|
:key="`${pageNum}-${index}`"
|
||||||
@click="typeof pageNum === 'number' && goToPage(pageNum)"
|
@click="typeof pageNum === 'number' && goToPage(pageNum)"
|
||||||
:disabled="typeof pageNum !== 'number'"
|
:disabled="typeof pageNum !== 'number'"
|
||||||
:class="[
|
:class="[
|
||||||
|
|||||||
@@ -66,8 +66,8 @@
|
|||||||
<!-- Progress bar -->
|
<!-- Progress bar -->
|
||||||
<div v-if="toast.duration" class="h-1 bg-gray-100 dark:bg-dark-700">
|
<div v-if="toast.duration" class="h-1 bg-gray-100 dark:bg-dark-700">
|
||||||
<div
|
<div
|
||||||
:class="['h-full transition-all', getProgressBarColor(toast.type)]"
|
:class="['h-full toast-progress', getProgressBarColor(toast.type)]"
|
||||||
:style="{ width: `${getProgress(toast)}%` }"
|
:style="{ animationDuration: `${toast.duration}ms` }"
|
||||||
></div>
|
></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -77,7 +77,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed, onMounted, onUnmounted } from 'vue'
|
import { computed } from 'vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
|
|
||||||
@@ -129,36 +129,25 @@ const getProgressBarColor = (type: string): string => {
|
|||||||
return colors[type] || colors.info
|
return colors[type] || colors.info
|
||||||
}
|
}
|
||||||
|
|
||||||
const getProgress = (toast: any): number => {
|
|
||||||
if (!toast.duration || !toast.startTime) return 100
|
|
||||||
const elapsed = Date.now() - toast.startTime
|
|
||||||
const progress = Math.max(0, 100 - (elapsed / toast.duration) * 100)
|
|
||||||
return progress
|
|
||||||
}
|
|
||||||
|
|
||||||
const removeToast = (id: string) => {
|
const removeToast = (id: string) => {
|
||||||
appStore.hideToast(id)
|
appStore.hideToast(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
let intervalId: number | undefined
|
|
||||||
|
|
||||||
onMounted(() => {
|
|
||||||
// Check for expired toasts every 100ms
|
|
||||||
intervalId = window.setInterval(() => {
|
|
||||||
const now = Date.now()
|
|
||||||
toasts.value.forEach((toast) => {
|
|
||||||
if (toast.duration && toast.startTime) {
|
|
||||||
if (now - toast.startTime >= toast.duration) {
|
|
||||||
removeToast(toast.id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}, 100)
|
|
||||||
})
|
|
||||||
|
|
||||||
onUnmounted(() => {
|
|
||||||
if (intervalId !== undefined) {
|
|
||||||
clearInterval(intervalId)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.toast-progress {
|
||||||
|
width: 100%;
|
||||||
|
animation-name: toast-progress-shrink;
|
||||||
|
animation-timing-function: linear;
|
||||||
|
animation-fill-mode: forwards;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes toast-progress-shrink {
|
||||||
|
from {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
to {
|
||||||
|
width: 0%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|||||||
@@ -143,7 +143,7 @@
|
|||||||
<!-- Options (for select/multi_select) -->
|
<!-- Options (for select/multi_select) -->
|
||||||
<div v-if="form.type === 'select' || form.type === 'multi_select'" class="space-y-2">
|
<div v-if="form.type === 'select' || form.type === 'multi_select'" class="space-y-2">
|
||||||
<label class="input-label">{{ t('admin.users.attributes.options') }}</label>
|
<label class="input-label">{{ t('admin.users.attributes.options') }}</label>
|
||||||
<div v-for="(option, index) in form.options" :key="index" class="flex items-center gap-2">
|
<div v-for="(option, index) in form.options" :key="getOptionKey(option)" class="flex items-center gap-2">
|
||||||
<input
|
<input
|
||||||
v-model="option.value"
|
v-model="option.value"
|
||||||
type="text"
|
type="text"
|
||||||
@@ -246,6 +246,7 @@ import BaseDialog from '@/components/common/BaseDialog.vue'
|
|||||||
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import Select from '@/components/common/Select.vue'
|
import Select from '@/components/common/Select.vue'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
@@ -270,6 +271,7 @@ const showEditModal = ref(false)
|
|||||||
const showDeleteDialog = ref(false)
|
const showDeleteDialog = ref(false)
|
||||||
const editingAttribute = ref<UserAttributeDefinition | null>(null)
|
const editingAttribute = ref<UserAttributeDefinition | null>(null)
|
||||||
const deletingAttribute = ref<UserAttributeDefinition | null>(null)
|
const deletingAttribute = ref<UserAttributeDefinition | null>(null)
|
||||||
|
const getOptionKey = createStableObjectKeyResolver<UserAttributeOption>('user-attr-option')
|
||||||
|
|
||||||
const form = reactive({
|
const form = reactive({
|
||||||
key: '',
|
key: '',
|
||||||
@@ -315,7 +317,7 @@ const openEditModal = (attr: UserAttributeDefinition) => {
|
|||||||
form.placeholder = attr.placeholder || ''
|
form.placeholder = attr.placeholder || ''
|
||||||
form.required = attr.required
|
form.required = attr.required
|
||||||
form.enabled = attr.enabled
|
form.enabled = attr.enabled
|
||||||
form.options = attr.options ? [...attr.options] : []
|
form.options = attr.options ? attr.options.map((opt) => ({ ...opt })) : []
|
||||||
showEditModal.value = true
|
showEditModal.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -88,7 +88,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, computed } from 'vue'
|
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { totpAPI } from '@/api'
|
import { totpAPI } from '@/api'
|
||||||
@@ -107,6 +107,7 @@ const loading = ref(false)
|
|||||||
const error = ref('')
|
const error = ref('')
|
||||||
const sendingCode = ref(false)
|
const sendingCode = ref(false)
|
||||||
const codeCooldown = ref(0)
|
const codeCooldown = ref(0)
|
||||||
|
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
|
||||||
const form = ref({
|
const form = ref({
|
||||||
emailCode: '',
|
emailCode: '',
|
||||||
password: ''
|
password: ''
|
||||||
@@ -139,10 +140,17 @@ const handleSendCode = async () => {
|
|||||||
appStore.showSuccess(t('profile.totp.codeSent'))
|
appStore.showSuccess(t('profile.totp.codeSent'))
|
||||||
// Start cooldown
|
// Start cooldown
|
||||||
codeCooldown.value = 60
|
codeCooldown.value = 60
|
||||||
const timer = setInterval(() => {
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
|
cooldownTimer.value = setInterval(() => {
|
||||||
codeCooldown.value--
|
codeCooldown.value--
|
||||||
if (codeCooldown.value <= 0) {
|
if (codeCooldown.value <= 0) {
|
||||||
clearInterval(timer)
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, 1000)
|
}, 1000)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
@@ -176,4 +184,11 @@ const handleDisable = async () => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadVerificationMethod()
|
loadVerificationMethod()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -175,7 +175,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, nextTick, watch, computed } from 'vue'
|
import { ref, onMounted, onUnmounted, nextTick, watch, computed } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { totpAPI } from '@/api'
|
import { totpAPI } from '@/api'
|
||||||
@@ -198,6 +198,7 @@ const verifyForm = ref({ emailCode: '', password: '' })
|
|||||||
const verifyError = ref('')
|
const verifyError = ref('')
|
||||||
const sendingCode = ref(false)
|
const sendingCode = ref(false)
|
||||||
const codeCooldown = ref(0)
|
const codeCooldown = ref(0)
|
||||||
|
const cooldownTimer = ref<ReturnType<typeof setInterval> | null>(null)
|
||||||
|
|
||||||
const setupLoading = ref(false)
|
const setupLoading = ref(false)
|
||||||
const setupData = ref<TotpSetupResponse | null>(null)
|
const setupData = ref<TotpSetupResponse | null>(null)
|
||||||
@@ -338,10 +339,17 @@ const handleSendCode = async () => {
|
|||||||
appStore.showSuccess(t('profile.totp.codeSent'))
|
appStore.showSuccess(t('profile.totp.codeSent'))
|
||||||
// Start cooldown
|
// Start cooldown
|
||||||
codeCooldown.value = 60
|
codeCooldown.value = 60
|
||||||
const timer = setInterval(() => {
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
|
cooldownTimer.value = setInterval(() => {
|
||||||
codeCooldown.value--
|
codeCooldown.value--
|
||||||
if (codeCooldown.value <= 0) {
|
if (codeCooldown.value <= 0) {
|
||||||
clearInterval(timer)
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, 1000)
|
}, 1000)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
@@ -397,4 +405,11 @@ const handleVerify = async () => {
|
|||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadVerificationMethod()
|
loadVerificationMethod()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
if (cooldownTimer.value) {
|
||||||
|
clearInterval(cooldownTimer.value)
|
||||||
|
cooldownTimer.value = null
|
||||||
|
}
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { mount } from '@vue/test-utils'
|
||||||
|
import TotpSetupModal from '@/components/user/profile/TotpSetupModal.vue'
|
||||||
|
import TotpDisableDialog from '@/components/user/profile/TotpDisableDialog.vue'
|
||||||
|
|
||||||
|
const mocks = vi.hoisted(() => ({
|
||||||
|
showSuccess: vi.fn(),
|
||||||
|
showError: vi.fn(),
|
||||||
|
getVerificationMethod: vi.fn(),
|
||||||
|
sendVerifyCode: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('vue-i18n', () => ({
|
||||||
|
useI18n: () => ({
|
||||||
|
t: (key: string) => key
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/stores/app', () => ({
|
||||||
|
useAppStore: () => ({
|
||||||
|
showSuccess: mocks.showSuccess,
|
||||||
|
showError: mocks.showError
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/api', () => ({
|
||||||
|
totpAPI: {
|
||||||
|
getVerificationMethod: mocks.getVerificationMethod,
|
||||||
|
sendVerifyCode: mocks.sendVerifyCode,
|
||||||
|
initiateSetup: vi.fn(),
|
||||||
|
enable: vi.fn(),
|
||||||
|
disable: vi.fn()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
const flushPromises = async () => {
|
||||||
|
await Promise.resolve()
|
||||||
|
await Promise.resolve()
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('TOTP 弹窗定时器清理', () => {
|
||||||
|
let intervalSeed = 1000
|
||||||
|
let setIntervalSpy: ReturnType<typeof vi.spyOn>
|
||||||
|
let clearIntervalSpy: ReturnType<typeof vi.spyOn>
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
intervalSeed = 1000
|
||||||
|
mocks.showSuccess.mockReset()
|
||||||
|
mocks.showError.mockReset()
|
||||||
|
mocks.getVerificationMethod.mockReset()
|
||||||
|
mocks.sendVerifyCode.mockReset()
|
||||||
|
|
||||||
|
mocks.getVerificationMethod.mockResolvedValue({ method: 'email' })
|
||||||
|
mocks.sendVerifyCode.mockResolvedValue({ success: true })
|
||||||
|
|
||||||
|
setIntervalSpy = vi.spyOn(window, 'setInterval').mockImplementation(((handler: TimerHandler) => {
|
||||||
|
void handler
|
||||||
|
intervalSeed += 1
|
||||||
|
return intervalSeed as unknown as number
|
||||||
|
}) as typeof window.setInterval)
|
||||||
|
clearIntervalSpy = vi.spyOn(window, 'clearInterval')
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
setIntervalSpy.mockRestore()
|
||||||
|
clearIntervalSpy.mockRestore()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('TotpSetupModal 卸载时清理倒计时定时器', async () => {
|
||||||
|
const wrapper = mount(TotpSetupModal)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const sendButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((button) => button.text().includes('profile.totp.sendCode'))
|
||||||
|
|
||||||
|
expect(sendButton).toBeTruthy()
|
||||||
|
await sendButton!.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
|
||||||
|
const timerId = setIntervalSpy.mock.results[0]?.value
|
||||||
|
|
||||||
|
wrapper.unmount()
|
||||||
|
|
||||||
|
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('TotpDisableDialog 卸载时清理倒计时定时器', async () => {
|
||||||
|
const wrapper = mount(TotpDisableDialog)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const sendButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((button) => button.text().includes('profile.totp.sendCode'))
|
||||||
|
|
||||||
|
expect(sendButton).toBeTruthy()
|
||||||
|
await sendButton!.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(setIntervalSpy).toHaveBeenCalledTimes(1)
|
||||||
|
const timerId = setIntervalSpy.mock.results[0]?.value
|
||||||
|
|
||||||
|
wrapper.unmount()
|
||||||
|
|
||||||
|
expect(clearIntervalSpy).toHaveBeenCalledWith(timerId)
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
||||||
|
|
||||||
|
const flushPromises = () => Promise.resolve()
|
||||||
|
|
||||||
|
describe('useKeyedDebouncedSearch', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.useFakeTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('为不同 key 独立防抖触发搜索', async () => {
|
||||||
|
const search = vi.fn().mockResolvedValue([])
|
||||||
|
const onSuccess = vi.fn()
|
||||||
|
|
||||||
|
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||||
|
delay: 100,
|
||||||
|
search,
|
||||||
|
onSuccess
|
||||||
|
})
|
||||||
|
|
||||||
|
searcher.trigger('a', 'foo')
|
||||||
|
searcher.trigger('b', 'bar')
|
||||||
|
|
||||||
|
expect(search).not.toHaveBeenCalled()
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(100)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(search).toHaveBeenCalledTimes(2)
|
||||||
|
expect(search).toHaveBeenNthCalledWith(
|
||||||
|
1,
|
||||||
|
'foo',
|
||||||
|
expect.objectContaining({ key: 'a', signal: expect.any(AbortSignal) })
|
||||||
|
)
|
||||||
|
expect(search).toHaveBeenNthCalledWith(
|
||||||
|
2,
|
||||||
|
'bar',
|
||||||
|
expect.objectContaining({ key: 'b', signal: expect.any(AbortSignal) })
|
||||||
|
)
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('同 key 新请求会取消旧请求并忽略过期响应', async () => {
|
||||||
|
const resolves: Array<(value: string[]) => void> = []
|
||||||
|
const search = vi.fn().mockImplementation(
|
||||||
|
() => new Promise<string[]>((resolve) => {
|
||||||
|
resolves.push(resolve)
|
||||||
|
})
|
||||||
|
)
|
||||||
|
const onSuccess = vi.fn()
|
||||||
|
|
||||||
|
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||||
|
delay: 50,
|
||||||
|
search,
|
||||||
|
onSuccess
|
||||||
|
})
|
||||||
|
|
||||||
|
searcher.trigger('rule-1', 'first')
|
||||||
|
vi.advanceTimersByTime(50)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
searcher.trigger('rule-1', 'second')
|
||||||
|
vi.advanceTimersByTime(50)
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(search).toHaveBeenCalledTimes(2)
|
||||||
|
|
||||||
|
resolves[1](['second'])
|
||||||
|
await flushPromises()
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(1)
|
||||||
|
expect(onSuccess).toHaveBeenLastCalledWith('rule-1', ['second'])
|
||||||
|
|
||||||
|
resolves[0](['first'])
|
||||||
|
await flushPromises()
|
||||||
|
expect(onSuccess).toHaveBeenCalledTimes(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clearKey 会取消未执行任务', () => {
|
||||||
|
const search = vi.fn().mockResolvedValue([])
|
||||||
|
const onSuccess = vi.fn()
|
||||||
|
|
||||||
|
const searcher = useKeyedDebouncedSearch<string[]>({
|
||||||
|
delay: 100,
|
||||||
|
search,
|
||||||
|
onSuccess
|
||||||
|
})
|
||||||
|
|
||||||
|
searcher.trigger('a', 'foo')
|
||||||
|
searcher.clearKey('a')
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(100)
|
||||||
|
|
||||||
|
expect(search).not.toHaveBeenCalled()
|
||||||
|
expect(onSuccess).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
103
frontend/src/composables/useKeyedDebouncedSearch.ts
Normal file
103
frontend/src/composables/useKeyedDebouncedSearch.ts
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import { getCurrentInstance, onUnmounted } from 'vue'
|
||||||
|
|
||||||
|
export interface KeyedDebouncedSearchContext {
|
||||||
|
key: string
|
||||||
|
signal: AbortSignal
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UseKeyedDebouncedSearchOptions<T> {
|
||||||
|
delay?: number
|
||||||
|
search: (keyword: string, context: KeyedDebouncedSearchContext) => Promise<T>
|
||||||
|
onSuccess: (key: string, result: T) => void
|
||||||
|
onError?: (key: string, error: unknown) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 多实例隔离的防抖搜索:每个 key 有独立的防抖、请求取消与过期响应保护。
|
||||||
|
*/
|
||||||
|
export function useKeyedDebouncedSearch<T>(options: UseKeyedDebouncedSearchOptions<T>) {
|
||||||
|
const delay = options.delay ?? 300
|
||||||
|
const timers = new Map<string, ReturnType<typeof setTimeout>>()
|
||||||
|
const controllers = new Map<string, AbortController>()
|
||||||
|
const versions = new Map<string, number>()
|
||||||
|
|
||||||
|
const clearKey = (key: string) => {
|
||||||
|
const timer = timers.get(key)
|
||||||
|
if (timer) {
|
||||||
|
clearTimeout(timer)
|
||||||
|
timers.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
const controller = controllers.get(key)
|
||||||
|
if (controller) {
|
||||||
|
controller.abort()
|
||||||
|
controllers.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
versions.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearAll = () => {
|
||||||
|
const allKeys = new Set<string>([
|
||||||
|
...timers.keys(),
|
||||||
|
...controllers.keys(),
|
||||||
|
...versions.keys()
|
||||||
|
])
|
||||||
|
|
||||||
|
allKeys.forEach((key) => clearKey(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
const trigger = (key: string, keyword: string) => {
|
||||||
|
const nextVersion = (versions.get(key) ?? 0) + 1
|
||||||
|
versions.set(key, nextVersion)
|
||||||
|
|
||||||
|
const existingTimer = timers.get(key)
|
||||||
|
if (existingTimer) {
|
||||||
|
clearTimeout(existingTimer)
|
||||||
|
timers.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
const inFlight = controllers.get(key)
|
||||||
|
if (inFlight) {
|
||||||
|
inFlight.abort()
|
||||||
|
controllers.delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
const timer = setTimeout(async () => {
|
||||||
|
timers.delete(key)
|
||||||
|
|
||||||
|
const controller = new AbortController()
|
||||||
|
controllers.set(key, controller)
|
||||||
|
const requestVersion = versions.get(key)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await options.search(keyword, { key, signal: controller.signal })
|
||||||
|
if (controller.signal.aborted) return
|
||||||
|
if (versions.get(key) !== requestVersion) return
|
||||||
|
options.onSuccess(key, result)
|
||||||
|
} catch (error) {
|
||||||
|
if (controller.signal.aborted) return
|
||||||
|
if (versions.get(key) !== requestVersion) return
|
||||||
|
options.onError?.(key, error)
|
||||||
|
} finally {
|
||||||
|
if (controllers.get(key) === controller) {
|
||||||
|
controllers.delete(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, delay)
|
||||||
|
|
||||||
|
timers.set(key, timer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getCurrentInstance()) {
|
||||||
|
onUnmounted(() => {
|
||||||
|
clearAll()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
trigger,
|
||||||
|
clearKey,
|
||||||
|
clearAll
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,53 +1,83 @@
|
|||||||
import { createI18n } from 'vue-i18n'
|
import { createI18n } from 'vue-i18n'
|
||||||
import en from './locales/en'
|
|
||||||
import zh from './locales/zh'
|
type LocaleCode = 'en' | 'zh'
|
||||||
|
|
||||||
|
type LocaleMessages = Record<string, any>
|
||||||
|
|
||||||
const LOCALE_KEY = 'sub2api_locale'
|
const LOCALE_KEY = 'sub2api_locale'
|
||||||
|
const DEFAULT_LOCALE: LocaleCode = 'en'
|
||||||
|
|
||||||
function getDefaultLocale(): string {
|
const localeLoaders: Record<LocaleCode, () => Promise<{ default: LocaleMessages }>> = {
|
||||||
// Check localStorage first
|
en: () => import('./locales/en'),
|
||||||
|
zh: () => import('./locales/zh')
|
||||||
|
}
|
||||||
|
|
||||||
|
function isLocaleCode(value: string): value is LocaleCode {
|
||||||
|
return value === 'en' || value === 'zh'
|
||||||
|
}
|
||||||
|
|
||||||
|
function getDefaultLocale(): LocaleCode {
|
||||||
const saved = localStorage.getItem(LOCALE_KEY)
|
const saved = localStorage.getItem(LOCALE_KEY)
|
||||||
if (saved && ['en', 'zh'].includes(saved)) {
|
if (saved && isLocaleCode(saved)) {
|
||||||
return saved
|
return saved
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check browser language
|
|
||||||
const browserLang = navigator.language.toLowerCase()
|
const browserLang = navigator.language.toLowerCase()
|
||||||
if (browserLang.startsWith('zh')) {
|
if (browserLang.startsWith('zh')) {
|
||||||
return 'zh'
|
return 'zh'
|
||||||
}
|
}
|
||||||
|
|
||||||
return 'en'
|
return DEFAULT_LOCALE
|
||||||
}
|
}
|
||||||
|
|
||||||
export const i18n = createI18n({
|
export const i18n = createI18n({
|
||||||
legacy: false,
|
legacy: false,
|
||||||
locale: getDefaultLocale(),
|
locale: getDefaultLocale(),
|
||||||
fallbackLocale: 'en',
|
fallbackLocale: DEFAULT_LOCALE,
|
||||||
messages: {
|
messages: {},
|
||||||
en,
|
|
||||||
zh
|
|
||||||
},
|
|
||||||
// 禁用 HTML 消息警告 - 引导步骤使用富文本内容(driver.js 支持 HTML)
|
// 禁用 HTML 消息警告 - 引导步骤使用富文本内容(driver.js 支持 HTML)
|
||||||
// 这些内容是内部定义的,不存在 XSS 风险
|
// 这些内容是内部定义的,不存在 XSS 风险
|
||||||
warnHtmlMessage: false
|
warnHtmlMessage: false
|
||||||
})
|
})
|
||||||
|
|
||||||
export function setLocale(locale: string) {
|
const loadedLocales = new Set<LocaleCode>()
|
||||||
if (['en', 'zh'].includes(locale)) {
|
|
||||||
i18n.global.locale.value = locale as 'en' | 'zh'
|
export async function loadLocaleMessages(locale: LocaleCode): Promise<void> {
|
||||||
localStorage.setItem(LOCALE_KEY, locale)
|
if (loadedLocales.has(locale)) {
|
||||||
document.documentElement.setAttribute('lang', locale)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const loader = localeLoaders[locale]
|
||||||
|
const module = await loader()
|
||||||
|
i18n.global.setLocaleMessage(locale, module.default)
|
||||||
|
loadedLocales.add(locale)
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getLocale(): string {
|
export async function initI18n(): Promise<void> {
|
||||||
return i18n.global.locale.value
|
const current = getLocale()
|
||||||
|
await loadLocaleMessages(current)
|
||||||
|
document.documentElement.setAttribute('lang', current)
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function setLocale(locale: string): Promise<void> {
|
||||||
|
if (!isLocaleCode(locale)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
await loadLocaleMessages(locale)
|
||||||
|
i18n.global.locale.value = locale
|
||||||
|
localStorage.setItem(LOCALE_KEY, locale)
|
||||||
|
document.documentElement.setAttribute('lang', locale)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getLocale(): LocaleCode {
|
||||||
|
const current = i18n.global.locale.value
|
||||||
|
return isLocaleCode(current) ? current : DEFAULT_LOCALE
|
||||||
}
|
}
|
||||||
|
|
||||||
export const availableLocales = [
|
export const availableLocales = [
|
||||||
{ code: 'en', name: 'English', flag: '🇺🇸' },
|
{ code: 'en', name: 'English', flag: '🇺🇸' },
|
||||||
{ code: 'zh', name: '中文', flag: '🇨🇳' }
|
{ code: 'zh', name: '中文', flag: '🇨🇳' }
|
||||||
]
|
] as const
|
||||||
|
|
||||||
export default i18n
|
export default i18n
|
||||||
|
|||||||
@@ -2,28 +2,33 @@ import { createApp } from 'vue'
|
|||||||
import { createPinia } from 'pinia'
|
import { createPinia } from 'pinia'
|
||||||
import App from './App.vue'
|
import App from './App.vue'
|
||||||
import router from './router'
|
import router from './router'
|
||||||
import i18n from './i18n'
|
import i18n, { initI18n } from './i18n'
|
||||||
|
import { useAppStore } from '@/stores/app'
|
||||||
import './style.css'
|
import './style.css'
|
||||||
|
|
||||||
const app = createApp(App)
|
async function bootstrap() {
|
||||||
const pinia = createPinia()
|
const app = createApp(App)
|
||||||
app.use(pinia)
|
const pinia = createPinia()
|
||||||
|
app.use(pinia)
|
||||||
|
|
||||||
// Initialize settings from injected config BEFORE mounting (prevents flash)
|
// Initialize settings from injected config BEFORE mounting (prevents flash)
|
||||||
// This must happen after pinia is installed but before router and i18n
|
// This must happen after pinia is installed but before router and i18n
|
||||||
import { useAppStore } from '@/stores/app'
|
const appStore = useAppStore()
|
||||||
const appStore = useAppStore()
|
appStore.initFromInjectedConfig()
|
||||||
appStore.initFromInjectedConfig()
|
|
||||||
|
|
||||||
// Set document title immediately after config is loaded
|
// Set document title immediately after config is loaded
|
||||||
if (appStore.siteName && appStore.siteName !== 'Sub2API') {
|
if (appStore.siteName && appStore.siteName !== 'Sub2API') {
|
||||||
document.title = `${appStore.siteName} - AI API Gateway`
|
document.title = `${appStore.siteName} - AI API Gateway`
|
||||||
|
}
|
||||||
|
|
||||||
|
await initI18n()
|
||||||
|
|
||||||
|
app.use(router)
|
||||||
|
app.use(i18n)
|
||||||
|
|
||||||
|
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
|
||||||
|
await router.isReady()
|
||||||
|
app.mount('#app')
|
||||||
}
|
}
|
||||||
|
|
||||||
app.use(router)
|
bootstrap()
|
||||||
app.use(i18n)
|
|
||||||
|
|
||||||
// 等待路由器完成初始导航后再挂载,避免竞态条件导致的空白渲染
|
|
||||||
router.isReady().then(() => {
|
|
||||||
app.mount('#app')
|
|
||||||
})
|
|
||||||
|
|||||||
25
frontend/src/router/__tests__/title.spec.ts
Normal file
25
frontend/src/router/__tests__/title.spec.ts
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import { resolveDocumentTitle } from '@/router/title'
|
||||||
|
|
||||||
|
describe('resolveDocumentTitle', () => {
|
||||||
|
it('路由存在标题时,使用“路由标题 - 站点名”格式', () => {
|
||||||
|
expect(resolveDocumentTitle('Usage Records', 'My Site')).toBe('Usage Records - My Site')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('路由无标题时,回退到站点名', () => {
|
||||||
|
expect(resolveDocumentTitle(undefined, 'My Site')).toBe('My Site')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('站点名为空时,回退默认站点名', () => {
|
||||||
|
expect(resolveDocumentTitle('Dashboard', '')).toBe('Dashboard - Sub2API')
|
||||||
|
expect(resolveDocumentTitle(undefined, ' ')).toBe('Sub2API')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('站点名变更时仅影响后续路由标题计算', () => {
|
||||||
|
const before = resolveDocumentTitle('Admin Dashboard', 'Alpha')
|
||||||
|
const after = resolveDocumentTitle('Admin Dashboard', 'Beta')
|
||||||
|
|
||||||
|
expect(before).toBe('Admin Dashboard - Alpha')
|
||||||
|
expect(after).toBe('Admin Dashboard - Beta')
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -8,6 +8,7 @@ import { useAuthStore } from '@/stores/auth'
|
|||||||
import { useAppStore } from '@/stores/app'
|
import { useAppStore } from '@/stores/app'
|
||||||
import { useNavigationLoadingState } from '@/composables/useNavigationLoading'
|
import { useNavigationLoadingState } from '@/composables/useNavigationLoading'
|
||||||
import { useRoutePrefetch } from '@/composables/useRoutePrefetch'
|
import { useRoutePrefetch } from '@/composables/useRoutePrefetch'
|
||||||
|
import { resolveDocumentTitle } from './title'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Route definitions with lazy loading
|
* Route definitions with lazy loading
|
||||||
@@ -389,12 +390,7 @@ router.beforeEach((to, _from, next) => {
|
|||||||
|
|
||||||
// Set page title
|
// Set page title
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
const siteName = appStore.siteName || 'Sub2API'
|
document.title = resolveDocumentTitle(to.meta.title, appStore.siteName)
|
||||||
if (to.meta.title) {
|
|
||||||
document.title = `${to.meta.title} - ${siteName}`
|
|
||||||
} else {
|
|
||||||
document.title = siteName
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if route requires authentication
|
// Check if route requires authentication
|
||||||
const requiresAuth = to.meta.requiresAuth !== false // Default to true
|
const requiresAuth = to.meta.requiresAuth !== false // Default to true
|
||||||
|
|||||||
12
frontend/src/router/title.ts
Normal file
12
frontend/src/router/title.ts
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
/**
|
||||||
|
* 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。
|
||||||
|
*/
|
||||||
|
export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string {
|
||||||
|
const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'Sub2API'
|
||||||
|
|
||||||
|
if (typeof routeTitle === 'string' && routeTitle.trim()) {
|
||||||
|
return `${routeTitle.trim()} - ${normalizedSiteName}`
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalizedSiteName
|
||||||
|
}
|
||||||
37
frontend/src/utils/__tests__/stableObjectKey.spec.ts
Normal file
37
frontend/src/utils/__tests__/stableObjectKey.spec.ts
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
|
|
||||||
|
describe('createStableObjectKeyResolver', () => {
|
||||||
|
it('对同一对象返回稳定 key', () => {
|
||||||
|
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
|
||||||
|
const obj = { value: 'a' }
|
||||||
|
|
||||||
|
const key1 = resolve(obj)
|
||||||
|
const key2 = resolve(obj)
|
||||||
|
|
||||||
|
expect(key1).toBe(key2)
|
||||||
|
expect(key1.startsWith('rule-')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('不同对象返回不同 key', () => {
|
||||||
|
const resolve = createStableObjectKeyResolver<{ value: string }>('rule')
|
||||||
|
|
||||||
|
const key1 = resolve({ value: 'a' })
|
||||||
|
const key2 = resolve({ value: 'a' })
|
||||||
|
|
||||||
|
expect(key1).not.toBe(key2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('不同 resolver 互不影响', () => {
|
||||||
|
const resolveA = createStableObjectKeyResolver<{ id: number }>('a')
|
||||||
|
const resolveB = createStableObjectKeyResolver<{ id: number }>('b')
|
||||||
|
const obj = { id: 1 }
|
||||||
|
|
||||||
|
const keyA = resolveA(obj)
|
||||||
|
const keyB = resolveB(obj)
|
||||||
|
|
||||||
|
expect(keyA).not.toBe(keyB)
|
||||||
|
expect(keyA.startsWith('a-')).toBe(true)
|
||||||
|
expect(keyB.startsWith('b-')).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
19
frontend/src/utils/stableObjectKey.ts
Normal file
19
frontend/src/utils/stableObjectKey.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
let globalStableObjectKeySeed = 0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 为对象实例生成稳定 key(基于 WeakMap,不污染业务对象)
|
||||||
|
*/
|
||||||
|
export function createStableObjectKeyResolver<T extends object>(prefix = 'item') {
|
||||||
|
const keyMap = new WeakMap<T, string>()
|
||||||
|
|
||||||
|
return (item: T): string => {
|
||||||
|
const cached = keyMap.get(item)
|
||||||
|
if (cached) {
|
||||||
|
return cached
|
||||||
|
}
|
||||||
|
|
||||||
|
const key = `${prefix}-${++globalStableObjectKeySeed}`
|
||||||
|
keyMap.set(item, key)
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -239,8 +239,8 @@
|
|||||||
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template>
|
<template #pagination><Pagination v-if="pagination.total > 0" :page="pagination.page" :total="pagination.total" :page-size="pagination.page_size" @update:page="handlePageChange" @update:pageSize="handlePageSizeChange" /></template>
|
||||||
</TablePageLayout>
|
</TablePageLayout>
|
||||||
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
|
<CreateAccountModal :show="showCreate" :proxies="proxies" :groups="groups" @close="showCreate = false" @created="reload" />
|
||||||
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="load" />
|
<EditAccountModal :show="showEdit" :account="edAcc" :proxies="proxies" :groups="groups" @close="showEdit = false" @updated="handleAccountUpdated" />
|
||||||
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="load" />
|
<ReAuthAccountModal :show="showReAuth" :account="reAuthAcc" @close="closeReAuthModal" @reauthorized="handleAccountUpdated" />
|
||||||
<AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" />
|
<AccountTestModal :show="showTest" :account="testingAcc" @close="closeTestModal" />
|
||||||
<AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" />
|
<AccountStatsModal :show="showStats" :account="statsAcc" @close="closeStatsModal" />
|
||||||
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
|
<AccountActionMenu :show="menu.show" :account="menu.acc" :position="menu.pos" @close="menu.show = false" @test="handleTest" @stats="handleViewStats" @reauth="handleReAuth" @refresh-token="handleRefresh" @reset-status="handleResetStatus" @clear-rate-limit="handleClearRateLimit" />
|
||||||
@@ -694,6 +694,53 @@ const handleBulkToggleSchedulable = async (schedulable: boolean) => {
|
|||||||
}
|
}
|
||||||
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
|
const handleBulkUpdated = () => { showBulkEdit.value = false; selIds.value = []; reload() }
|
||||||
const handleDataImported = () => { showImportData.value = false; reload() }
|
const handleDataImported = () => { showImportData.value = false; reload() }
|
||||||
|
const accountMatchesCurrentFilters = (account: Account) => {
|
||||||
|
if (params.platform && account.platform !== params.platform) return false
|
||||||
|
if (params.type && account.type !== params.type) return false
|
||||||
|
if (params.status) {
|
||||||
|
if (params.status === 'rate_limited') {
|
||||||
|
if (!account.rate_limit_reset_at) return false
|
||||||
|
const resetAt = new Date(account.rate_limit_reset_at).getTime()
|
||||||
|
if (!Number.isFinite(resetAt) || resetAt <= Date.now()) return false
|
||||||
|
} else if (account.status !== params.status) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const search = String(params.search || '').trim().toLowerCase()
|
||||||
|
if (search && !account.name.toLowerCase().includes(search)) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
const mergeRuntimeFields = (oldAccount: Account, updatedAccount: Account): Account => ({
|
||||||
|
...updatedAccount,
|
||||||
|
current_concurrency: updatedAccount.current_concurrency ?? oldAccount.current_concurrency,
|
||||||
|
current_window_cost: updatedAccount.current_window_cost ?? oldAccount.current_window_cost,
|
||||||
|
active_sessions: updatedAccount.active_sessions ?? oldAccount.active_sessions
|
||||||
|
})
|
||||||
|
const patchAccountInList = (updatedAccount: Account) => {
|
||||||
|
const index = accounts.value.findIndex(account => account.id === updatedAccount.id)
|
||||||
|
if (index === -1) return
|
||||||
|
const mergedAccount = mergeRuntimeFields(accounts.value[index], updatedAccount)
|
||||||
|
if (!accountMatchesCurrentFilters(mergedAccount)) {
|
||||||
|
accounts.value = accounts.value.filter(account => account.id !== mergedAccount.id)
|
||||||
|
selIds.value = selIds.value.filter(id => id !== mergedAccount.id)
|
||||||
|
if (menu.acc?.id === mergedAccount.id) {
|
||||||
|
menu.show = false
|
||||||
|
menu.acc = null
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const nextAccounts = [...accounts.value]
|
||||||
|
nextAccounts[index] = mergedAccount
|
||||||
|
accounts.value = nextAccounts
|
||||||
|
if (edAcc.value?.id === mergedAccount.id) edAcc.value = mergedAccount
|
||||||
|
if (reAuthAcc.value?.id === mergedAccount.id) reAuthAcc.value = mergedAccount
|
||||||
|
if (tempUnschedAcc.value?.id === mergedAccount.id) tempUnschedAcc.value = mergedAccount
|
||||||
|
if (deletingAcc.value?.id === mergedAccount.id) deletingAcc.value = mergedAccount
|
||||||
|
if (menu.acc?.id === mergedAccount.id) menu.acc = mergedAccount
|
||||||
|
}
|
||||||
|
const handleAccountUpdated = (updatedAccount: Account) => {
|
||||||
|
patchAccountInList(updatedAccount)
|
||||||
|
}
|
||||||
const formatExportTimestamp = () => {
|
const formatExportTimestamp = () => {
|
||||||
const now = new Date()
|
const now = new Date()
|
||||||
const pad2 = (value: number) => String(value).padStart(2, '0')
|
const pad2 = (value: number) => String(value).padStart(2, '0')
|
||||||
@@ -743,9 +790,32 @@ const closeReAuthModal = () => { showReAuth.value = false; reAuthAcc.value = nul
|
|||||||
const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true }
|
const handleTest = (a: Account) => { testingAcc.value = a; showTest.value = true }
|
||||||
const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true }
|
const handleViewStats = (a: Account) => { statsAcc.value = a; showStats.value = true }
|
||||||
const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true }
|
const handleReAuth = (a: Account) => { reAuthAcc.value = a; showReAuth.value = true }
|
||||||
const handleRefresh = async (a: Account) => { try { await adminAPI.accounts.refreshCredentials(a.id); load() } catch (error) { console.error('Failed to refresh credentials:', error) } }
|
const handleRefresh = async (a: Account) => {
|
||||||
const handleResetStatus = async (a: Account) => { try { await adminAPI.accounts.clearError(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to reset status:', error) } }
|
try {
|
||||||
const handleClearRateLimit = async (a: Account) => { try { await adminAPI.accounts.clearRateLimit(a.id); appStore.showSuccess(t('common.success')); load() } catch (error) { console.error('Failed to clear rate limit:', error) } }
|
const updated = await adminAPI.accounts.refreshCredentials(a.id)
|
||||||
|
patchAccountInList(updated)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to refresh credentials:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const handleResetStatus = async (a: Account) => {
|
||||||
|
try {
|
||||||
|
const updated = await adminAPI.accounts.clearError(a.id)
|
||||||
|
patchAccountInList(updated)
|
||||||
|
appStore.showSuccess(t('common.success'))
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to reset status:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const handleClearRateLimit = async (a: Account) => {
|
||||||
|
try {
|
||||||
|
const updated = await adminAPI.accounts.clearRateLimit(a.id)
|
||||||
|
patchAccountInList(updated)
|
||||||
|
appStore.showSuccess(t('common.success'))
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to clear rate limit:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
|
const handleDelete = (a: Account) => { deletingAcc.value = a; showDeleteDialog.value = true }
|
||||||
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
|
const confirmDelete = async () => { if(!deletingAcc.value) return; try { await adminAPI.accounts.delete(deletingAcc.value.id); showDeleteDialog.value = false; deletingAcc.value = null; reload() } catch (error) { console.error('Failed to delete account:', error) } }
|
||||||
const handleToggleSchedulable = async (a: Account) => {
|
const handleToggleSchedulable = async (a: Account) => {
|
||||||
@@ -762,7 +832,17 @@ const handleToggleSchedulable = async (a: Account) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
|
const handleShowTempUnsched = (a: Account) => { tempUnschedAcc.value = a; showTempUnsched.value = true }
|
||||||
const handleTempUnschedReset = async () => { if(!tempUnschedAcc.value) return; try { await adminAPI.accounts.clearError(tempUnschedAcc.value.id); showTempUnsched.value = false; tempUnschedAcc.value = null; load() } catch (error) { console.error('Failed to reset temp unscheduled:', error) } }
|
const handleTempUnschedReset = async () => {
|
||||||
|
if(!tempUnschedAcc.value) return
|
||||||
|
try {
|
||||||
|
const updated = await adminAPI.accounts.clearError(tempUnschedAcc.value.id)
|
||||||
|
showTempUnsched.value = false
|
||||||
|
tempUnschedAcc.value = null
|
||||||
|
patchAccountInList(updated)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to reset temp unscheduled:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
const formatExpiresAt = (value: number | null) => {
|
const formatExpiresAt = (value: number | null) => {
|
||||||
if (!value) return '-'
|
if (!value) return '-'
|
||||||
return formatDateTime(
|
return formatDateTime(
|
||||||
|
|||||||
@@ -759,8 +759,8 @@
|
|||||||
<!-- 路由规则列表(仅在启用时显示) -->
|
<!-- 路由规则列表(仅在启用时显示) -->
|
||||||
<div v-if="createForm.model_routing_enabled" class="space-y-3">
|
<div v-if="createForm.model_routing_enabled" class="space-y-3">
|
||||||
<div
|
<div
|
||||||
v-for="(rule, index) in createModelRoutingRules"
|
v-for="rule in createModelRoutingRules"
|
||||||
:key="index"
|
:key="getCreateRuleRenderKey(rule)"
|
||||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-start gap-3">
|
<div class="flex items-start gap-3">
|
||||||
@@ -786,7 +786,7 @@
|
|||||||
{{ account.name }}
|
{{ account.name }}
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="removeSelectedAccount(index, account.id, false)"
|
@click="removeSelectedAccount(rule, account.id)"
|
||||||
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
||||||
>
|
>
|
||||||
<Icon name="x" size="xs" />
|
<Icon name="x" size="xs" />
|
||||||
@@ -796,23 +796,23 @@
|
|||||||
<!-- 账号搜索输入框 -->
|
<!-- 账号搜索输入框 -->
|
||||||
<div class="relative account-search-container">
|
<div class="relative account-search-container">
|
||||||
<input
|
<input
|
||||||
v-model="accountSearchKeyword[`create-${index}`]"
|
v-model="accountSearchKeyword[getCreateRuleSearchKey(rule)]"
|
||||||
type="text"
|
type="text"
|
||||||
class="input text-sm"
|
class="input text-sm"
|
||||||
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
||||||
@input="searchAccounts(`create-${index}`)"
|
@input="searchAccountsByRule(rule)"
|
||||||
@focus="onAccountSearchFocus(index, false)"
|
@focus="onAccountSearchFocus(rule)"
|
||||||
/>
|
/>
|
||||||
<!-- 搜索结果下拉框 -->
|
<!-- 搜索结果下拉框 -->
|
||||||
<div
|
<div
|
||||||
v-if="showAccountDropdown[`create-${index}`] && accountSearchResults[`create-${index}`]?.length > 0"
|
v-if="showAccountDropdown[getCreateRuleSearchKey(rule)] && accountSearchResults[getCreateRuleSearchKey(rule)]?.length > 0"
|
||||||
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
||||||
>
|
>
|
||||||
<button
|
<button
|
||||||
v-for="account in accountSearchResults[`create-${index}`]"
|
v-for="account in accountSearchResults[getCreateRuleSearchKey(rule)]"
|
||||||
:key="account.id"
|
:key="account.id"
|
||||||
type="button"
|
type="button"
|
||||||
@click="selectAccount(index, account, false)"
|
@click="selectAccount(rule, account)"
|
||||||
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
||||||
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
||||||
:disabled="rule.accounts.some(a => a.id === account.id)"
|
:disabled="rule.accounts.some(a => a.id === account.id)"
|
||||||
@@ -827,7 +827,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="removeCreateRoutingRule(index)"
|
@click="removeCreateRoutingRule(rule)"
|
||||||
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
||||||
:title="t('admin.groups.modelRouting.removeRule')"
|
:title="t('admin.groups.modelRouting.removeRule')"
|
||||||
>
|
>
|
||||||
@@ -1439,8 +1439,8 @@
|
|||||||
<!-- 路由规则列表(仅在启用时显示) -->
|
<!-- 路由规则列表(仅在启用时显示) -->
|
||||||
<div v-if="editForm.model_routing_enabled" class="space-y-3">
|
<div v-if="editForm.model_routing_enabled" class="space-y-3">
|
||||||
<div
|
<div
|
||||||
v-for="(rule, index) in editModelRoutingRules"
|
v-for="rule in editModelRoutingRules"
|
||||||
:key="index"
|
:key="getEditRuleRenderKey(rule)"
|
||||||
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
class="rounded-lg border border-gray-200 p-3 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
<div class="flex items-start gap-3">
|
<div class="flex items-start gap-3">
|
||||||
@@ -1466,7 +1466,7 @@
|
|||||||
{{ account.name }}
|
{{ account.name }}
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="removeSelectedAccount(index, account.id, true)"
|
@click="removeSelectedAccount(rule, account.id, true)"
|
||||||
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
class="ml-0.5 text-primary-500 hover:text-primary-700 dark:hover:text-primary-200"
|
||||||
>
|
>
|
||||||
<Icon name="x" size="xs" />
|
<Icon name="x" size="xs" />
|
||||||
@@ -1476,23 +1476,23 @@
|
|||||||
<!-- 账号搜索输入框 -->
|
<!-- 账号搜索输入框 -->
|
||||||
<div class="relative account-search-container">
|
<div class="relative account-search-container">
|
||||||
<input
|
<input
|
||||||
v-model="accountSearchKeyword[`edit-${index}`]"
|
v-model="accountSearchKeyword[getEditRuleSearchKey(rule)]"
|
||||||
type="text"
|
type="text"
|
||||||
class="input text-sm"
|
class="input text-sm"
|
||||||
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
:placeholder="t('admin.groups.modelRouting.searchAccountPlaceholder')"
|
||||||
@input="searchAccounts(`edit-${index}`)"
|
@input="searchAccountsByRule(rule, true)"
|
||||||
@focus="onAccountSearchFocus(index, true)"
|
@focus="onAccountSearchFocus(rule, true)"
|
||||||
/>
|
/>
|
||||||
<!-- 搜索结果下拉框 -->
|
<!-- 搜索结果下拉框 -->
|
||||||
<div
|
<div
|
||||||
v-if="showAccountDropdown[`edit-${index}`] && accountSearchResults[`edit-${index}`]?.length > 0"
|
v-if="showAccountDropdown[getEditRuleSearchKey(rule)] && accountSearchResults[getEditRuleSearchKey(rule)]?.length > 0"
|
||||||
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
class="absolute z-50 mt-1 max-h-48 w-full overflow-auto rounded-lg border bg-white shadow-lg dark:border-dark-600 dark:bg-dark-800"
|
||||||
>
|
>
|
||||||
<button
|
<button
|
||||||
v-for="account in accountSearchResults[`edit-${index}`]"
|
v-for="account in accountSearchResults[getEditRuleSearchKey(rule)]"
|
||||||
:key="account.id"
|
:key="account.id"
|
||||||
type="button"
|
type="button"
|
||||||
@click="selectAccount(index, account, true)"
|
@click="selectAccount(rule, account, true)"
|
||||||
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
class="w-full px-3 py-2 text-left text-sm hover:bg-gray-100 dark:hover:bg-dark-700"
|
||||||
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
:class="{ 'opacity-50': rule.accounts.some(a => a.id === account.id) }"
|
||||||
:disabled="rule.accounts.some(a => a.id === account.id)"
|
:disabled="rule.accounts.some(a => a.id === account.id)"
|
||||||
@@ -1507,7 +1507,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="removeEditRoutingRule(index)"
|
@click="removeEditRoutingRule(rule)"
|
||||||
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
class="mt-5 p-1.5 text-gray-400 hover:text-red-500 transition-colors"
|
||||||
:title="t('admin.groups.modelRouting.removeRule')"
|
:title="t('admin.groups.modelRouting.removeRule')"
|
||||||
>
|
>
|
||||||
@@ -1687,6 +1687,8 @@ import Select from '@/components/common/Select.vue'
|
|||||||
import PlatformIcon from '@/components/common/PlatformIcon.vue'
|
import PlatformIcon from '@/components/common/PlatformIcon.vue'
|
||||||
import Icon from '@/components/icons/Icon.vue'
|
import Icon from '@/components/icons/Icon.vue'
|
||||||
import { VueDraggable } from 'vue-draggable-plus'
|
import { VueDraggable } from 'vue-draggable-plus'
|
||||||
|
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
|
||||||
|
import { useKeyedDebouncedSearch } from '@/composables/useKeyedDebouncedSearch'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
const appStore = useAppStore()
|
const appStore = useAppStore()
|
||||||
@@ -1911,33 +1913,70 @@ const createModelRoutingRules = ref<ModelRoutingRule[]>([])
|
|||||||
// 编辑表单的模型路由规则
|
// 编辑表单的模型路由规则
|
||||||
const editModelRoutingRules = ref<ModelRoutingRule[]>([])
|
const editModelRoutingRules = ref<ModelRoutingRule[]>([])
|
||||||
|
|
||||||
// 账号搜索相关状态
|
// 规则对象稳定 key(避免使用 index 导致状态错位)
|
||||||
const accountSearchKeyword = ref<Record<string, string>>({}) // 每个规则的搜索关键词 (key: "create-0" 或 "edit-0")
|
const resolveCreateRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('create-rule')
|
||||||
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({}) // 每个规则的搜索结果
|
const resolveEditRuleKey = createStableObjectKeyResolver<ModelRoutingRule>('edit-rule')
|
||||||
const showAccountDropdown = ref<Record<string, boolean>>({}) // 每个规则的下拉框显示状态
|
|
||||||
let accountSearchTimeout: ReturnType<typeof setTimeout> | null = null
|
|
||||||
|
|
||||||
// 搜索账号(仅限 anthropic 平台)
|
const getCreateRuleRenderKey = (rule: ModelRoutingRule) => resolveCreateRuleKey(rule)
|
||||||
const searchAccounts = async (key: string) => {
|
const getEditRuleRenderKey = (rule: ModelRoutingRule) => resolveEditRuleKey(rule)
|
||||||
if (accountSearchTimeout) clearTimeout(accountSearchTimeout)
|
|
||||||
accountSearchTimeout = setTimeout(async () => {
|
const getCreateRuleSearchKey = (rule: ModelRoutingRule) => `create-${resolveCreateRuleKey(rule)}`
|
||||||
const keyword = accountSearchKeyword.value[key] || ''
|
const getEditRuleSearchKey = (rule: ModelRoutingRule) => `edit-${resolveEditRuleKey(rule)}`
|
||||||
try {
|
|
||||||
const res = await adminAPI.accounts.list(1, 20, {
|
const getRuleSearchKey = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||||
|
return isEdit ? getEditRuleSearchKey(rule) : getCreateRuleSearchKey(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 账号搜索相关状态
|
||||||
|
const accountSearchKeyword = ref<Record<string, string>>({})
|
||||||
|
const accountSearchResults = ref<Record<string, SimpleAccount[]>>({})
|
||||||
|
const showAccountDropdown = ref<Record<string, boolean>>({})
|
||||||
|
|
||||||
|
const clearAccountSearchStateByKey = (key: string) => {
|
||||||
|
delete accountSearchKeyword.value[key]
|
||||||
|
delete accountSearchResults.value[key]
|
||||||
|
delete showAccountDropdown.value[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
const clearAllAccountSearchState = () => {
|
||||||
|
accountSearchKeyword.value = {}
|
||||||
|
accountSearchResults.value = {}
|
||||||
|
showAccountDropdown.value = {}
|
||||||
|
}
|
||||||
|
|
||||||
|
const accountSearchRunner = useKeyedDebouncedSearch<SimpleAccount[]>({
|
||||||
|
delay: 300,
|
||||||
|
search: async (keyword, { signal }) => {
|
||||||
|
const res = await adminAPI.accounts.list(
|
||||||
|
1,
|
||||||
|
20,
|
||||||
|
{
|
||||||
search: keyword,
|
search: keyword,
|
||||||
platform: 'anthropic'
|
platform: 'anthropic'
|
||||||
})
|
},
|
||||||
accountSearchResults.value[key] = res.items.map((a) => ({ id: a.id, name: a.name }))
|
{ signal }
|
||||||
} catch {
|
)
|
||||||
accountSearchResults.value[key] = []
|
return res.items.map((account) => ({ id: account.id, name: account.name }))
|
||||||
}
|
},
|
||||||
}, 300)
|
onSuccess: (key, result) => {
|
||||||
|
accountSearchResults.value[key] = result
|
||||||
|
},
|
||||||
|
onError: (key) => {
|
||||||
|
accountSearchResults.value[key] = []
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 搜索账号(仅限 anthropic 平台)
|
||||||
|
const searchAccounts = (key: string) => {
|
||||||
|
accountSearchRunner.trigger(key, accountSearchKeyword.value[key] || '')
|
||||||
|
}
|
||||||
|
|
||||||
|
const searchAccountsByRule = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||||
|
searchAccounts(getRuleSearchKey(rule, isEdit))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择账号
|
// 选择账号
|
||||||
const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolean = false) => {
|
const selectAccount = (rule: ModelRoutingRule, account: SimpleAccount, isEdit: boolean = false) => {
|
||||||
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
|
|
||||||
const rule = rules[ruleIndex]
|
|
||||||
if (!rule) return
|
if (!rule) return
|
||||||
|
|
||||||
// 检查是否已选择
|
// 检查是否已选择
|
||||||
@@ -1946,15 +1985,13 @@ const selectAccount = (ruleIndex: number, account: SimpleAccount, isEdit: boolea
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 清空搜索
|
// 清空搜索
|
||||||
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}`
|
const key = getRuleSearchKey(rule, isEdit)
|
||||||
accountSearchKeyword.value[key] = ''
|
accountSearchKeyword.value[key] = ''
|
||||||
showAccountDropdown.value[key] = false
|
showAccountDropdown.value[key] = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 移除已选账号
|
// 移除已选账号
|
||||||
const removeSelectedAccount = (ruleIndex: number, accountId: number, isEdit: boolean = false) => {
|
const removeSelectedAccount = (rule: ModelRoutingRule, accountId: number, _isEdit: boolean = false) => {
|
||||||
const rules = isEdit ? editModelRoutingRules.value : createModelRoutingRules.value
|
|
||||||
const rule = rules[ruleIndex]
|
|
||||||
if (!rule) return
|
if (!rule) return
|
||||||
|
|
||||||
rule.accounts = rule.accounts.filter(a => a.id !== accountId)
|
rule.accounts = rule.accounts.filter(a => a.id !== accountId)
|
||||||
@@ -1981,8 +2018,8 @@ const toggleEditScope = (scope: string) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理账号搜索输入框聚焦
|
// 处理账号搜索输入框聚焦
|
||||||
const onAccountSearchFocus = (ruleIndex: number, isEdit: boolean = false) => {
|
const onAccountSearchFocus = (rule: ModelRoutingRule, isEdit: boolean = false) => {
|
||||||
const key = `${isEdit ? 'edit' : 'create'}-${ruleIndex}`
|
const key = getRuleSearchKey(rule, isEdit)
|
||||||
showAccountDropdown.value[key] = true
|
showAccountDropdown.value[key] = true
|
||||||
// 如果没有搜索结果,触发一次搜索
|
// 如果没有搜索结果,触发一次搜索
|
||||||
if (!accountSearchResults.value[key]?.length) {
|
if (!accountSearchResults.value[key]?.length) {
|
||||||
@@ -1996,13 +2033,14 @@ const addCreateRoutingRule = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 删除创建表单的路由规则
|
// 删除创建表单的路由规则
|
||||||
const removeCreateRoutingRule = (index: number) => {
|
const removeCreateRoutingRule = (rule: ModelRoutingRule) => {
|
||||||
|
const index = createModelRoutingRules.value.indexOf(rule)
|
||||||
|
if (index === -1) return
|
||||||
|
|
||||||
|
const key = getCreateRuleSearchKey(rule)
|
||||||
|
accountSearchRunner.clearKey(key)
|
||||||
|
clearAccountSearchStateByKey(key)
|
||||||
createModelRoutingRules.value.splice(index, 1)
|
createModelRoutingRules.value.splice(index, 1)
|
||||||
// 清理相关的搜索状态
|
|
||||||
const key = `create-${index}`
|
|
||||||
delete accountSearchKeyword.value[key]
|
|
||||||
delete accountSearchResults.value[key]
|
|
||||||
delete showAccountDropdown.value[key]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加编辑表单的路由规则
|
// 添加编辑表单的路由规则
|
||||||
@@ -2011,13 +2049,14 @@ const addEditRoutingRule = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 删除编辑表单的路由规则
|
// 删除编辑表单的路由规则
|
||||||
const removeEditRoutingRule = (index: number) => {
|
const removeEditRoutingRule = (rule: ModelRoutingRule) => {
|
||||||
|
const index = editModelRoutingRules.value.indexOf(rule)
|
||||||
|
if (index === -1) return
|
||||||
|
|
||||||
|
const key = getEditRuleSearchKey(rule)
|
||||||
|
accountSearchRunner.clearKey(key)
|
||||||
|
clearAccountSearchStateByKey(key)
|
||||||
editModelRoutingRules.value.splice(index, 1)
|
editModelRoutingRules.value.splice(index, 1)
|
||||||
// 清理相关的搜索状态
|
|
||||||
const key = `edit-${index}`
|
|
||||||
delete accountSearchKeyword.value[key]
|
|
||||||
delete accountSearchResults.value[key]
|
|
||||||
delete showAccountDropdown.value[key]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 将 UI 格式的路由规则转换为 API 格式
|
// 将 UI 格式的路由规则转换为 API 格式
|
||||||
@@ -2161,6 +2200,10 @@ const handlePageSizeChange = (pageSize: number) => {
|
|||||||
|
|
||||||
const closeCreateModal = () => {
|
const closeCreateModal = () => {
|
||||||
showCreateModal.value = false
|
showCreateModal.value = false
|
||||||
|
createModelRoutingRules.value.forEach((rule) => {
|
||||||
|
accountSearchRunner.clearKey(getCreateRuleSearchKey(rule))
|
||||||
|
})
|
||||||
|
clearAllAccountSearchState()
|
||||||
createForm.name = ''
|
createForm.name = ''
|
||||||
createForm.description = ''
|
createForm.description = ''
|
||||||
createForm.platform = 'anthropic'
|
createForm.platform = 'anthropic'
|
||||||
@@ -2247,6 +2290,10 @@ const handleEdit = async (group: AdminGroup) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const closeEditModal = () => {
|
const closeEditModal = () => {
|
||||||
|
editModelRoutingRules.value.forEach((rule) => {
|
||||||
|
accountSearchRunner.clearKey(getEditRuleSearchKey(rule))
|
||||||
|
})
|
||||||
|
clearAllAccountSearchState()
|
||||||
showEditModal.value = false
|
showEditModal.value = false
|
||||||
editingGroup.value = null
|
editingGroup.value = null
|
||||||
editModelRoutingRules.value = []
|
editModelRoutingRules.value = []
|
||||||
@@ -2382,5 +2429,7 @@ onMounted(() => {
|
|||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
document.removeEventListener('click', handleClickOutside)
|
document.removeEventListener('click', handleClickOutside)
|
||||||
|
accountSearchRunner.clearAll()
|
||||||
|
clearAllAccountSearchState()
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -94,57 +94,44 @@ const exportToExcel = async () => {
|
|||||||
if (exporting.value) return; exporting.value = true; exportProgress.show = true
|
if (exporting.value) return; exporting.value = true; exportProgress.show = true
|
||||||
const c = new AbortController(); exportAbortController = c
|
const c = new AbortController(); exportAbortController = c
|
||||||
try {
|
try {
|
||||||
const all: AdminUsageLog[] = []; let p = 1; let total = pagination.total
|
let p = 1; let total = pagination.total; let exportedCount = 0
|
||||||
|
const XLSX = await import('xlsx')
|
||||||
|
const headers = [
|
||||||
|
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
|
||||||
|
t('admin.usage.account'), t('usage.model'), t('usage.reasoningEffort'), t('admin.usage.group'),
|
||||||
|
t('usage.type'),
|
||||||
|
t('admin.usage.inputTokens'), t('admin.usage.outputTokens'),
|
||||||
|
t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'),
|
||||||
|
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
|
||||||
|
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
|
||||||
|
t('usage.rate'), t('usage.accountMultiplier'), t('usage.original'), t('usage.userBilled'), t('usage.accountBilled'),
|
||||||
|
t('usage.firstToken'), t('usage.duration'),
|
||||||
|
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
|
||||||
|
]
|
||||||
|
const ws = XLSX.utils.aoa_to_sheet([headers])
|
||||||
while (true) {
|
while (true) {
|
||||||
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
|
const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
|
||||||
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
|
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
|
||||||
if (res.items?.length) all.push(...res.items)
|
const rows = (res.items || []).map((log: AdminUsageLog) => [
|
||||||
exportProgress.current = all.length; exportProgress.progress = total > 0 ? Math.min(100, Math.round(all.length/total*100)) : 0
|
log.created_at, log.user?.email || '', log.api_key?.name || '', log.account?.name || '', log.model,
|
||||||
if (all.length >= total || res.items.length < 100) break; p++
|
formatReasoningEffort(log.reasoning_effort), log.group?.name || '', log.stream ? t('usage.stream') : t('usage.sync'),
|
||||||
|
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
|
||||||
|
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
|
||||||
|
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
|
||||||
|
log.rate_multiplier?.toFixed(2) || '1.00', (log.account_rate_multiplier ?? 1).toFixed(2),
|
||||||
|
log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000',
|
||||||
|
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6), log.first_token_ms ?? '', log.duration_ms,
|
||||||
|
log.request_id || '', log.user_agent || '', log.ip_address || ''
|
||||||
|
])
|
||||||
|
if (rows.length) {
|
||||||
|
XLSX.utils.sheet_add_aoa(ws, rows, { origin: -1 })
|
||||||
|
}
|
||||||
|
exportedCount += rows.length
|
||||||
|
exportProgress.current = exportedCount
|
||||||
|
exportProgress.progress = total > 0 ? Math.min(100, Math.round(exportedCount / total * 100)) : 0
|
||||||
|
if (exportedCount >= total || res.items.length < 100) break; p++
|
||||||
}
|
}
|
||||||
if(!c.signal.aborted) {
|
if(!c.signal.aborted) {
|
||||||
const XLSX = await import('xlsx')
|
|
||||||
const headers = [
|
|
||||||
t('usage.time'), t('admin.usage.user'), t('usage.apiKeyFilter'),
|
|
||||||
t('admin.usage.account'), t('usage.model'), t('usage.reasoningEffort'), t('admin.usage.group'),
|
|
||||||
t('usage.type'),
|
|
||||||
t('admin.usage.inputTokens'), t('admin.usage.outputTokens'),
|
|
||||||
t('admin.usage.cacheReadTokens'), t('admin.usage.cacheCreationTokens'),
|
|
||||||
t('admin.usage.inputCost'), t('admin.usage.outputCost'),
|
|
||||||
t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'),
|
|
||||||
t('usage.rate'), t('usage.accountMultiplier'), t('usage.original'), t('usage.userBilled'), t('usage.accountBilled'),
|
|
||||||
t('usage.firstToken'), t('usage.duration'),
|
|
||||||
t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress')
|
|
||||||
]
|
|
||||||
const rows = all.map(log => [
|
|
||||||
log.created_at,
|
|
||||||
log.user?.email || '',
|
|
||||||
log.api_key?.name || '',
|
|
||||||
log.account?.name || '',
|
|
||||||
log.model,
|
|
||||||
formatReasoningEffort(log.reasoning_effort),
|
|
||||||
log.group?.name || '',
|
|
||||||
log.stream ? t('usage.stream') : t('usage.sync'),
|
|
||||||
log.input_tokens,
|
|
||||||
log.output_tokens,
|
|
||||||
log.cache_read_tokens,
|
|
||||||
log.cache_creation_tokens,
|
|
||||||
log.input_cost?.toFixed(6) || '0.000000',
|
|
||||||
log.output_cost?.toFixed(6) || '0.000000',
|
|
||||||
log.cache_read_cost?.toFixed(6) || '0.000000',
|
|
||||||
log.cache_creation_cost?.toFixed(6) || '0.000000',
|
|
||||||
log.rate_multiplier?.toFixed(2) || '1.00',
|
|
||||||
(log.account_rate_multiplier ?? 1).toFixed(2),
|
|
||||||
log.total_cost?.toFixed(6) || '0.000000',
|
|
||||||
log.actual_cost?.toFixed(6) || '0.000000',
|
|
||||||
(log.total_cost * (log.account_rate_multiplier ?? 1)).toFixed(6),
|
|
||||||
log.first_token_ms ?? '',
|
|
||||||
log.duration_ms,
|
|
||||||
log.request_id || '',
|
|
||||||
log.user_agent || '',
|
|
||||||
log.ip_address || ''
|
|
||||||
])
|
|
||||||
const ws = XLSX.utils.aoa_to_sheet([headers, ...rows])
|
|
||||||
const wb = XLSX.utils.book_new()
|
const wb = XLSX.utils.book_new()
|
||||||
XLSX.utils.book_append_sheet(wb, ws, 'Usage')
|
XLSX.utils.book_append_sheet(wb, ws, 'Usage')
|
||||||
saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`)
|
saveAs(new Blob([XLSX.write(wb, { bookType: 'xlsx', type: 'array' })], { type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' }), `usage_${filters.value.start_date}_to_${filters.value.end_date}.xlsx`)
|
||||||
|
|||||||
Reference in New Issue
Block a user