diff --git a/common/ip.go b/common/ip.go new file mode 100644 index 00000000..bfb64ee7 --- /dev/null +++ b/common/ip.go @@ -0,0 +1,22 @@ +package common + +import "net" + +func IsPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + return false +} diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go new file mode 100644 index 00000000..6f7d289f --- /dev/null +++ b/common/ssrf_protection.go @@ -0,0 +1,327 @@ +package common + +import ( + "fmt" + "net" + "net/url" + "strconv" + "strings" +) + +// SSRFProtection SSRF防护配置 +type SSRFProtection struct { + AllowPrivateIp bool + DomainFilterMode bool // true: 白名单, false: 黑名单 + DomainList []string // domain format, e.g. example.com, *.example.com + IpFilterMode bool // true: 白名单, false: 黑名单 + IpList []string // CIDR or single IP + AllowedPorts []int // 允许的端口范围 + ApplyIPFilterForDomain bool // 对域名启用IP过滤 +} + +// DefaultSSRFProtection 默认SSRF防护配置 +var DefaultSSRFProtection = &SSRFProtection{ + AllowPrivateIp: false, + DomainFilterMode: true, + DomainList: []string{}, + IpFilterMode: true, + IpList: []string{}, + AllowedPorts: []int{}, +} + +// isPrivateIP 检查IP是否为私有地址 +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // 检查私有网段 + private := []net.IPNet{ + {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8 + {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12 + {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16 + {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8 + {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地) + {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播) + {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留) + } + + for _, privateNet := range private { + if privateNet.Contains(ip) { + return true + } + } + + // 检查IPv6私有地址 + if ip.To4() == nil { + // IPv6 loopback + if ip.Equal(net.IPv6loopback) { + return true + } + // IPv6 link-local + if strings.HasPrefix(ip.String(), "fe80:") { + return true + } + // IPv6 unique local + if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") { + return true + } + } + + return false +} + +// parsePortRanges 解析端口范围配置 +// 支持格式: "80", "443", "8000-9000" +func parsePortRanges(portConfigs []string) ([]int, error) { + var ports []int + + for _, config := range portConfigs { + config = strings.TrimSpace(config) + if config == "" { + continue + } + + if strings.Contains(config, "-") { + // 处理端口范围 "8000-9000" + parts := strings.Split(config, "-") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid port range format: %s", config) + } + + startPort, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return nil, fmt.Errorf("invalid start port in range %s: %v", config, err) + } + + endPort, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return nil, fmt.Errorf("invalid end port in range %s: %v", config, err) + } + + if startPort > endPort { + return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config) + } + + if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 { + return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config) + } + + // 添加范围内的所有端口 + for port := startPort; port <= endPort; port++ { + ports = append(ports, port) + } + } else { + // 处理单个端口 "80" + port, err := strconv.Atoi(config) + if err != nil { + return nil, fmt.Errorf("invalid port number: %s", config) + } + + if port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port) + } + + ports = append(ports, port) + } + } + + return ports, nil +} + +// isAllowedPort 检查端口是否被允许 +func (p *SSRFProtection) isAllowedPort(port int) bool { + if len(p.AllowedPorts) == 0 { + return true // 如果没有配置端口限制,则允许所有端口 + } + + for _, allowedPort := range p.AllowedPorts { + if port == allowedPort { + return true + } + } + return false +} + +// isDomainWhitelisted 检查域名是否在白名单中 +func isDomainListed(domain string, list []string) bool { + if len(list) == 0 { + return false + } + + domain = strings.ToLower(domain) + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } + // 精确匹配 + if domain == item { + return true + } + // 通配符匹配 (*.example.com) + if strings.HasPrefix(item, "*.") { + suffix := strings.TrimPrefix(item, "*.") + if strings.HasSuffix(domain, "."+suffix) || domain == suffix { + return true + } + } + } + return false +} + +func (p *SSRFProtection) isDomainAllowed(domain string) bool { + listed := isDomainListed(domain, p.DomainList) + if p.DomainFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// isIPWhitelisted 检查IP是否在白名单中 + +func isIPListed(ip net.IP, list []string) bool { + if len(list) == 0 { + return false + } + + for _, whitelistCIDR := range list { + _, network, err := net.ParseCIDR(whitelistCIDR) + if err != nil { + // 尝试作为单个IP处理 + if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil { + if ip.Equal(whitelistIP) { + return true + } + } + continue + } + + if network.Contains(ip) { + return true + } + } + return false +} + +// IsIPAccessAllowed 检查IP是否允许访问 +func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { + // 私有IP限制 + if isPrivateIP(ip) && !p.AllowPrivateIp { + return false + } + + listed := isIPListed(ip, p.IpList) + if p.IpFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + +// ValidateURL 验证URL是否安全 +func (p *SSRFProtection) ValidateURL(urlStr string) error { + // 解析URL + u, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format: %v", err) + } + + // 只允许HTTP/HTTPS协议 + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) + } + + // 解析主机和端口 + host, portStr, err := net.SplitHostPort(u.Host) + if err != nil { + // 没有端口,使用默认端口 + host = u.Hostname() + if u.Scheme == "https" { + portStr = "443" + } else { + portStr = "80" + } + } + + // 验证端口 + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("invalid port: %s", portStr) + } + + if !p.isAllowedPort(port) { + return fmt.Errorf("port %d is not allowed", port) + } + + // 如果 host 是 IP,则跳过域名检查 + if ip := net.ParseIP(host); ip != nil { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) { + return fmt.Errorf("private IP address not allowed: %s", ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s", ip.String()) + } + return fmt.Errorf("ip in blacklist: %s", ip.String()) + } + return nil + } + + // 先进行域名过滤 + if !p.isDomainAllowed(host) { + if p.DomainFilterMode { + return fmt.Errorf("domain not in whitelist: %s", host) + } + return fmt.Errorf("domain in blacklist: %s", host) + } + + // 若未启用对域名应用IP过滤,则到此通过 + if !p.ApplyIPFilterForDomain { + return nil + } + + // 解析域名对应IP并检查 + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("DNS resolution failed for %s: %v", host, err) + } + for _, ip := range ips { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) && !p.AllowPrivateIp { + return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) + } + return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) + } + } + return nil +} + +// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL +func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { + // 如果SSRF防护被禁用,直接返回成功 + if !enableSSRFProtection { + return nil + } + + // 解析端口范围配置 + allowedPortInts, err := parsePortRanges(allowedPorts) + if err != nil { + return fmt.Errorf("request reject - invalid port configuration: %v", err) + } + + protection := &SSRFProtection{ + AllowPrivateIp: allowPrivateIp, + DomainFilterMode: domainFilterMode, + DomainList: domainList, + IpFilterMode: ipFilterMode, + IpList: ipList, + AllowedPorts: allowedPortInts, + ApplyIPFilterForDomain: applyIPFilterForDomain, + } + return protection.ValidateURL(urlStr) +} diff --git a/constant/task.go b/constant/task.go index 21790145..e174fd60 100644 --- a/constant/task.go +++ b/constant/task.go @@ -11,8 +11,10 @@ const ( SunoActionMusic = "MUSIC" SunoActionLyrics = "LYRICS" - TaskActionGenerate = "generate" - TaskActionTextGenerate = "textGenerate" + TaskActionGenerate = "generate" + TaskActionTextGenerate = "textGenerate" + TaskActionFirstTailGenerate = "firstTailGenerate" + TaskActionReferenceGenerate = "referenceGenerate" ) var SunoModel2Action = map[string]string{ diff --git a/controller/channel-test.go b/controller/channel-test.go index 5a668c48..9ea6eed7 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath = "/v1/embeddings" // 修改请求路径 } + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } + c.Request = &http.Request{ Method: "POST", URL: &url.URL{Path: requestPath}, // 使用动态路径 @@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } + // 重新检查模型类型并更新请求路径 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || + strings.Contains(testModel, "bge-") || + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { + requestPath = "/v1/embeddings" + c.Request.URL.Path = requestPath + } + + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + c.Request.URL.Path = requestPath + } + cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { if c.Request.URL.Path == "/v1/embeddings" { relayFormat = types.RelayFormatEmbedding } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) @@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult { } // 调用专门用于 Embedding 的转换函数 convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) + } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { + // 创建一个 ImageRequest + prompt := "cat" + if request.Prompt != nil { + if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { + prompt = promptStr + } + } + imageRequest := dto.ImageRequest{ + Prompt: prompt, + Model: request.Model, + N: uint(request.N), + Size: request.Size, + } + // 调用专门用于图像生成的转换函数 + convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) } else { // 对其他所有请求类型(如 Chat),保持原有逻辑 convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) diff --git a/controller/channel.go b/controller/channel.go index 403eb04c..17154ab0 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -501,9 +501,10 @@ func validateChannel(channel *model.Channel, isAdd bool) error { } type AddChannelRequest struct { - Mode string `json:"mode"` - MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` - Channel *model.Channel `json:"channel"` + Mode string `json:"mode"` + MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` + BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"` + Channel *model.Channel `json:"channel"` } func getVertexArrayKeys(keys string) ([]string, error) { @@ -616,6 +617,13 @@ func AddChannel(c *gin.Context) { } localChannel := addChannelRequest.Channel localChannel.Key = key + if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 { + keyPrefix := localChannel.Key + if len(localChannel.Key) > 8 { + keyPrefix = localChannel.Key[:8] + } + localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix) + } channels = append(channels, *localChannel) } err = model.BatchInsertChannels(channels) diff --git a/controller/option.go b/controller/option.go index e5f2b75b..7d1c676f 100644 --- a/controller/option.go +++ b/controller/option.go @@ -128,6 +128,33 @@ func UpdateOption(c *gin.Context) { }) return } + case "ImageRatio": + err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "图片倍率设置失败: " + err.Error(), + }) + return + } + case "AudioRatio": + err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "音频倍率设置失败: " + err.Error(), + }) + return + } + case "AudioCompletionRatio": + err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "音频补全倍率设置失败: " + err.Error(), + }) + return + } case "ModelRequestRateLimitGroup": err = setting.CheckModelRequestRateLimitGroup(option.Value.(string)) if err != nil { diff --git a/controller/setup.go b/controller/setup.go index 44a7b3a7..3ae255e9 100644 --- a/controller/setup.go +++ b/controller/setup.go @@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) { func PostSetup(c *gin.Context) { // Check if setup is already completed if constant.Setup { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统已经初始化完成", }) @@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) { var req SetupRequest err := c.ShouldBindJSON(&req) if err != nil { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "请求参数有误", }) @@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) { if !rootExists { // Validate username length: max 12 characters to align with model.User validation if len(req.Username) > 12 { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "用户名长度不能超过12个字符", }) @@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) { } // Validate password if req.Password != req.ConfirmPassword { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "两次输入的密码不一致", }) @@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) { } if len(req.Password) < 8 { - c.JSON(400, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "密码长度至少为8个字符", }) @@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) { // Create root user hashedPassword, err := common.Password2Hash(req.Password) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统错误: " + err.Error(), }) @@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) { } err = model.DB.Create(&rootUser).Error if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "创建管理员账号失败: " + err.Error(), }) @@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) { // Save operation modes to database for persistence err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled)) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "保存自用模式设置失败: " + err.Error(), }) @@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) { err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled)) if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "保存演示站点模式设置失败: " + err.Error(), }) @@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) { } err = model.DB.Create(&setup).Error if err != nil { - c.JSON(500, gin.H{ + c.JSON(200, gin.H{ "success": false, "message": "系统初始化失败: " + err.Error(), }) diff --git a/dto/openai_request.go b/dto/openai_request.go index cd05a63c..191fa638 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -772,11 +772,12 @@ type OpenAIResponsesRequest struct { Instructions json.RawMessage `json:"instructions,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` + Store json.RawMessage `json:"store,omitempty"` + PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"` Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` Text json.RawMessage `json:"text,omitempty"` diff --git a/dto/openai_response.go b/dto/openai_response.go index 966748cb..6353c15f 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -6,6 +6,10 @@ import ( "one-api/types" ) +const ( + ResponsesOutputTypeImageGenerationCall = "image_generation_call" +) + type SimpleResponse struct { Usage `json:"usage"` Error any `json:"error"` @@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError { return GetOpenAIError(o.Error) } +func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool { + if len(o.Output) == 0 { + return false + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return true + } + } + return false +} + +func (o *OpenAIResponsesResponse) GetQuality() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Quality + } + } + return "" +} + +func (o *OpenAIResponsesResponse) GetSize() string { + if len(o.Output) == 0 { + return "" + } + for _, output := range o.Output { + if output.Type == ResponsesOutputTypeImageGenerationCall { + return output.Size + } + } + return "" +} + type IncompleteDetails struct { Reasoning string `json:"reasoning"` } @@ -283,6 +323,8 @@ type ResponsesOutput struct { Status string `json:"status"` Role string `json:"role"` Content []ResponsesOutputContent `json:"content"` + Quality string `json:"quality"` + Size string `json:"size"` } type ResponsesOutputContent struct { diff --git a/model/option.go b/model/option.go index fefee4e7..ceecff65 100644 --- a/model/option.go +++ b/model/option.go @@ -112,6 +112,9 @@ func InitOptionMap() { common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString() + common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString() + common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString() + common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink //common.OptionMap["ChatLink"] = common.ChatLink //common.OptionMap["ChatLink2"] = common.ChatLink2 @@ -397,6 +400,12 @@ func updateOptionMap(key string, value string) (err error) { err = ratio_setting.UpdateModelPriceByJSONString(value) case "CacheRatio": err = ratio_setting.UpdateCacheRatioByJSONString(value) + case "ImageRatio": + err = ratio_setting.UpdateImageRatioByJSONString(value) + case "AudioRatio": + err = ratio_setting.UpdateAudioRatioByJSONString(value) + case "AudioCompletionRatio": + err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value //case "ChatLink": diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 17d732ab..962f8794 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -3,17 +3,17 @@ package deepseek import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" "strings" - - "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } @@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fimBaseUrl := info.ChannelBaseUrl - if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { - fimBaseUrl += "/beta" - } - switch info.RelayMode { - case constant.RelayModeCompletions: - return fmt.Sprintf("%s/completions", fimBaseUrl), nil + switch info.RelayFormat { + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { + fimBaseUrl += "/beta" + } + switch info.RelayMode { + case constant.RelayModeCompletions: + return fmt.Sprintf("%s/completions", fimBaseUrl), nil + default: + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil + } } } @@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + switch info.RelayFormat { + case types.RelayFormatClaude: + if info.IsStream { + return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) + } else { + return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + adaptor := openai.Adaptor{} + return adaptor.DoResponse(c, resp, info) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 4968f78f..57542aa5 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { - if strings.HasSuffix(info.RequestURLPath, ":embedContent") || - strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") { + if strings.Contains(info.RequestURLPath, ":embedContent") || + strings.Contains(info.RequestURLPath, ":batchEmbedContents") { return NativeGeminiEmbeddingHandler(c, resp, info) } if info.IsStream { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index eb4afbae..199c8466 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -23,6 +23,7 @@ import ( "github.com/gin-gonic/gin" ) +// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob var geminiSupportedMimeTypes = map[string]bool{ "application/pdf": true, "audio/mpeg": true, @@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{ "audio/wav": true, "image/png": true, "image/jpeg": true, + "image/webp": true, "text/plain": true, "video/mov": true, "video/mpeg": true, diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index e290c239..f24976bb 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { - adaptor := openai.Adaptor{} + adaptor := claude.Adaptor{} return adaptor.ConvertClaudeRequest(c, info, req) } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index e188889e..85938a77 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } + if responsesResponse.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", responsesResponse.GetQuality()) + c.Set("image_generation_call_size", responsesResponse.GetSize()) + } + // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) @@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": - if streamResponse.Response != nil && streamResponse.Response.Usage != nil { - if streamResponse.Response.Usage.InputTokens != 0 { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens + if streamResponse.Response != nil { + if streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } + if streamResponse.Response.Usage.InputTokensDetails != nil { + usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + } } - if streamResponse.Response.Usage.OutputTokens != 0 { - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - } - if streamResponse.Response.Usage.TotalTokens != 0 { - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - } - if streamResponse.Response.Usage.InputTokensDetails != nil { - usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens + if streamResponse.Response.HasImageGenerationCall() { + c.Set("image_generation_call", true) + c.Set("image_generation_call_quality", streamResponse.Response.GetQuality()) + c.Set("image_generation_call_size", streamResponse.Response.GetSize()) } } case "response.output_text.delta": diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index b954d7b8..a2545a27 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -94,6 +94,9 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + if isNewAPIRelay(info.ApiKey) { + return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil + } return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil } @@ -101,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - return a.signRequest(req, a.accessKey, a.secretKey) + if isNewAPIRelay(info.ApiKey) { + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + } else { + return a.signRequest(req, a.accessKey, a.secretKey) + } + return nil } // BuildRequestBody converts request into Jimeng specific format. @@ -161,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http } uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) + if isNewAPIRelay(key) { + uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL) + } payload := map[string]string{ "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, @@ -178,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - keyParts := strings.Split(key, "|") - if len(keyParts) != 2 { - return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") - } - accessKey := strings.TrimSpace(keyParts[0]) - secretKey := strings.TrimSpace(keyParts[1]) + if isNewAPIRelay(key) { + req.Header.Set("Authorization", "Bearer "+key) + } else { + keyParts := strings.Split(key, "|") + if len(keyParts) != 2 { + return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") + } + accessKey := strings.TrimSpace(keyParts[0]) + secretKey := strings.TrimSpace(keyParts[1]) - if err := a.signRequest(req, accessKey, secretKey); err != nil { - return nil, errors.Wrap(err, "sign request failed") + if err := a.signRequest(req, accessKey, secretKey); err != nil { + return nil, errors.Wrap(err, "sign request failed") + } } - return service.GetHttpClient().Do(req) } @@ -384,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Url = resTask.Data.VideoUrl return &taskResult, nil } + +func isNewAPIRelay(apiKey string) bool { + return strings.HasPrefix(apiKey, "sk-") +} diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 13f2af97..fec3396a 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -117,6 +117,11 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") + + if isNewAPIRelay(info.ApiKey) { + return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil + } + return fmt.Sprintf("%s%s", a.baseURL, path), nil } @@ -199,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http } path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID) + if isNewAPIRelay(key) { + url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID) + } req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -304,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) { //} func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { - + if isNewAPIRelay(apiKey) { + return apiKey, nil // new api relay + } keyParts := strings.Split(apiKey, "|") + if len(keyParts) != 2 { + return "", errors.New("invalid api_key, required format is accessKey|secretKey") + } accessKey := strings.TrimSpace(keyParts[0]) if len(keyParts) == 1 { return accessKey, nil @@ -352,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } return taskInfo, nil } + +func isNewAPIRelay(apiKey string) bool { + return strings.HasPrefix(apiKey, "sk-") +} diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index a1140d1e..358aef58 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { - // Use the unified validation method for TaskSubmitReq with image-based action determination - return relaycommon.ValidateTaskRequestWithImageBinding(c, info) + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { @@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro switch info.Action { case constant.TaskActionGenerate: path = "/img2video" + case constant.TaskActionFirstTailGenerate: + path = "/start-end2video" + case constant.TaskActionReferenceGenerate: + path = "/reference2video" default: path = "/text2video" } @@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string { // ============================ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { - var images []string - if req.Image != "" { - images = []string{req.Image} - } - r := requestPayload{ Model: defaultString(req.Model, "viduq1"), - Images: images, + Images: req.Images, Prompt: req.Prompt, Duration: defaultInt(req.Duration, 5), Resolution: defaultString(req.Size, "1080p"), diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 0af019da..eb88412a 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { + case constant.RelayModeImagesGenerations: + return request, nil case constant.RelayModeImagesEdits: var requestBody bytes.Buffer diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go index 30cc902e..fca10e7c 100644 --- a/relay/channel/volcengine/constants.go +++ b/relay/channel/volcengine/constants.go @@ -8,6 +8,7 @@ var ModelList = []string{ "Doubao-lite-32k", "Doubao-lite-4k", "Doubao-embedding", + "doubao-seedream-4-0-250828", } var ChannelName = "volcengine" diff --git a/relay/claude_handler.go b/relay/claude_handler.go index dbdc6ee1..59d12abe 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ info.UpstreamModelName = request.Model } + if info.ChannelSetting.SystemPrompt != "" { + if request.System == nil { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + if request.IsStringSystem() { + existing := strings.TrimSpace(request.GetStringSystem()) + if existing == "" { + request.SetStringSystem(info.ChannelSetting.SystemPrompt) + } else { + request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing) + } + } else { + systemContents := request.ParseSystem() + newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText} + newSystem.SetText(info.ChannelSetting.SystemPrompt) + if len(systemContents) == 0 { + request.System = []dto.ClaudeMediaMessage{newSystem} + } else { + request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...) + } + } + } + } + var requestBody io.Reader if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled { body, err := common.GetRequestBody(c) diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index cf6d08dd..3a721b47 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d req.Images = []string{req.Image} } + if req.HasImage() { + action = constant.TaskActionGenerate + if info.ChannelType == constant.ChannelTypeVidu { + // vidu 增加 首尾帧生视频和参考图生视频 + if len(req.Images) == 2 { + action = constant.TaskActionFirstTailGenerate + } else if len(req.Images) > 2 { + action = constant.TaskActionReferenceGenerate + } + } + } + storeTaskRequest(c, info, action, req) return nil } - -func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError { - hasPrompt, ok := requestObj.(HasPrompt) - if !ok { - return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true) - } - - if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil { - return taskErr - } - - action := constant.TaskActionTextGenerate - if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() { - action = constant.TaskActionGenerate - } - - storeTaskRequest(c, info, action, requestObj) - return nil -} - -func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError { - var req TaskSubmitReq - if err := c.ShouldBindJSON(&req); err != nil { - return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false) - } - - return ValidateTaskRequestWithImage(c, info, req) -} diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 01ab1fff..38b820f7 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 - request := convertedRequest.(*dto.GeneralOpenAIRequest) - containSystemPrompt := false - for _, message := range request.Messages { - if message.Role == request.GetSystemRoleName() { - containSystemPrompt = true - break - } - } - if !containSystemPrompt { - // 如果没有系统提示,则添加系统提示 - systemMessage := dto.Message{ - Role: request.GetSystemRoleName(), - Content: info.ChannelSetting.SystemPrompt, - } - request.Messages = append([]dto.Message{systemMessage}, request.Messages...) - } else if info.ChannelSetting.SystemPromptOverride { - common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) - // 如果有系统提示,且允许覆盖,则拼接到前面 - for i, message := range request.Messages { + request, ok := convertedRequest.(*dto.GeneralOpenAIRequest) + if ok { + containSystemPrompt := false + for _, message := range request.Messages { if message.Role == request.GetSystemRoleName() { - if message.IsStringContent() { - request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) - } else { - contents := message.ParseContent() - contents = append([]dto.MediaContent{ - { - Type: dto.ContentTypeText, - Text: info.ChannelSetting.SystemPrompt, - }, - }, contents...) - request.Messages[i].Content = contents - } + containSystemPrompt = true break } } + if !containSystemPrompt { + // 如果没有系统提示,则添加系统提示 + systemMessage := dto.Message{ + Role: request.GetSystemRoleName(), + Content: info.ChannelSetting.SystemPrompt, + } + request.Messages = append([]dto.Message{systemMessage}, request.Messages...) + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + // 如果有系统提示,且允许覆盖,则拼接到前面 + for i, message := range request.Messages { + if message.Role == request.GetSystemRoleName() { + if message.IsStringContent() { + request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + } else { + contents := message.ParseContent() + contents = append([]dto.MediaContent{ + { + Type: dto.ContentTypeText, + Text: info.ChannelSetting.SystemPrompt, + }, + }, contents...) + request.Messages[i].Content = contents + } + break + } + } + } } } @@ -276,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage fileSearchTool.CallCount, dFileSearchQuota.String()) } } + var dImageGenerationCallQuota decimal.Decimal + var imageGenerationCallPrice float64 + if ctx.GetBool("image_generation_call") { + imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) + dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()) + } var quotaCalculateDecimal decimal.Decimal @@ -331,6 +340,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) + // 添加 image generation call 计费 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -429,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } + if !dImageGenerationCallQuota.IsZero() { + other["image_generation_call"] = true + other["image_generation_call_price"] = imageGenerationCallPrice + } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 0252d657..1410da60 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/logger" "one-api/relay/channel/gemini" @@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ adaptor.Init(info) + if info.ChannelSetting.SystemPrompt != "" { + if request.SystemInstructions == nil { + request.SystemInstructions = &dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ + {Text: info.ChannelSetting.SystemPrompt}, + }, + } + } else if len(request.SystemInstructions.Parts) == 0 { + request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}} + } else if info.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + merged := false + for i := range request.SystemInstructions.Parts { + if request.SystemInstructions.Parts[i].Text == "" { + continue + } + request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text + merged = true + break + } + if !merged { + request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...) + } + } + } + // Clean up empty system instruction if request.SystemInstructions != nil { hasContent := false diff --git a/relay/helper/price.go b/relay/helper/price.go index fdc5b66d..c23c068b 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var cacheRatio float64 var imageRatio float64 var cacheCreationRatio float64 + var audioRatio float64 + var audioCompletionRatio float64 if !usePrice { preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) if meta.MaxTokens != 0 { @@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName) cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName) imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName) + audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName) + audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName) ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens UsePrice: usePrice, CacheRatio: cacheRatio, ImageRatio: imageRatio, + AudioRatio: audioRatio, + AudioCompletionRatio: audioCompletionRatio, CacheCreationRatio: cacheCreationRatio, ShouldPreConsumedQuota: preConsumedQuota, } diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 4d1c1f9b..f4a290ec 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt case types.RelayFormatOpenAI: request, err = GetAndValidateTextRequest(c, relayMode) case types.RelayFormatGemini: - request, err = GetAndValidateGeminiRequest(c) + if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") { + request, err = GetAndValidateGeminiEmbeddingRequest(c) + } else { + request, err = GetAndValidateGeminiRequest(c) + } case types.RelayFormatClaude: request, err = GetAndValidateClaudeRequest(c) case types.RelayFormatOpenAIResponses: @@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA } func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) { - request := &dto.GeminiChatRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) return request, nil } + +func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) { + request := &dto.GeminiEmbeddingRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + return request, nil +} diff --git a/service/cf_worker.go b/service/download.go similarity index 59% rename from service/cf_worker.go rename to service/download.go index d60b6fad..036c43af 100644 --- a/service/cf_worker.go +++ b/service/download.go @@ -28,6 +28,12 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { return nil, fmt.Errorf("only support https url") } + // SSRF防护:验证请求URL + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + workerUrl := system_setting.WorkerUrl if !strings.HasSuffix(workerUrl, "/") { workerUrl += "/" @@ -51,7 +57,13 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, } return DoWorkerRequest(req) } else { - common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) + // SSRF防护:验证请求URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return nil, fmt.Errorf("request reject: %v", err) + } + + common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", "))) return http.Get(originUrl) } } diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3cfabc1a..0cf53513 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -19,7 +19,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { gopool.Go(func() { relayInfoCopy := *relayInfo - err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false) + err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false) if err != nil { common.SysLog("error return pre-consumed quota: " + err.Error()) } diff --git a/service/user_notify.go b/service/user_notify.go index 972ca655..fba12d9d 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -113,6 +113,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Bark URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + // 直接发送请求 req, err = http.NewRequest(http.MethodGet, finalURL, nil) if err != nil { diff --git a/service/webhook.go b/service/webhook.go index 9c6ec810..c678b863 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "net/http" + "one-api/common" "one-api/dto" "one-api/setting/system_setting" "time" @@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode) } } else { + // SSRF防护:验证Webhook URL(非Worker模式) + fetchSetting := system_setting.GetFetchSetting() + if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("request reject: %v", err) + } + req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes)) if err != nil { return fmt.Errorf("failed to create webhook request: %v", err) diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index 549a1862..5b89d6fe 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -10,6 +10,18 @@ const ( FileSearchPrice = 2.5 ) +const ( + GPTImage1Low1024x1024 = 0.011 + GPTImage1Low1024x1536 = 0.016 + GPTImage1Low1536x1024 = 0.016 + GPTImage1Medium1024x1024 = 0.042 + GPTImage1Medium1024x1536 = 0.063 + GPTImage1Medium1536x1024 = 0.063 + GPTImage1High1024x1024 = 0.167 + GPTImage1High1024x1536 = 0.25 + GPTImage1High1536x1024 = 0.25 +) + const ( // Gemini Audio Input Price Gemini25FlashPreviewInputAudioPrice = 1.00 @@ -65,3 +77,31 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } + +func GetGPTImage1PriceOnceCall(quality string, size string) float64 { + prices := map[string]map[string]float64{ + "low": { + "1024x1024": GPTImage1Low1024x1024, + "1024x1536": GPTImage1Low1024x1536, + "1536x1024": GPTImage1Low1536x1024, + }, + "medium": { + "1024x1024": GPTImage1Medium1024x1024, + "1024x1536": GPTImage1Medium1024x1536, + "1536x1024": GPTImage1Medium1536x1024, + }, + "high": { + "1024x1024": GPTImage1High1024x1024, + "1024x1536": GPTImage1High1024x1536, + "1536x1024": GPTImage1High1536x1024, + }, + } + + if qualityMap, exists := prices[quality]; exists { + if price, exists := qualityMap[size]; exists { + return price + } + } + + return GPTImage1High1024x1024 +} diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index f06cd71e..362c6fa1 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -178,6 +178,7 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, + "gemini-embedding-001": 0.075, "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -278,6 +279,18 @@ var defaultModelPrice = map[string]float64{ "mj_upload": 0.05, } +var defaultAudioRatio = map[string]float64{ + "gpt-4o-audio-preview": 16, + "gpt-4o-mini-audio-preview": 66.67, + "gpt-4o-realtime-preview": 8, + "gpt-4o-mini-realtime-preview": 16.67, +} + +var defaultAudioCompletionRatio = map[string]float64{ + "gpt-4o-realtime": 2, + "gpt-4o-mini-realtime": 2, +} + var ( modelPriceMap map[string]float64 = nil modelPriceMapMutex = sync.RWMutex{} @@ -326,6 +339,15 @@ func InitRatioSettings() { imageRatioMap = defaultImageRatio imageRatioMapMutex.Unlock() + // initialize audioRatioMap + audioRatioMapMutex.Lock() + audioRatioMap = defaultAudioRatio + audioRatioMapMutex.Unlock() + + // initialize audioCompletionRatioMap + audioCompletionRatioMapMutex.Lock() + audioCompletionRatioMap = defaultAudioCompletionRatio + audioCompletionRatioMapMutex.Unlock() } func GetModelPriceMap() map[string]float64 { @@ -417,6 +439,18 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } +func GetDefaultImageRatioMap() map[string]float64 { + return defaultImageRatio +} + +func GetDefaultAudioRatioMap() map[string]float64 { + return defaultAudioRatio +} + +func GetDefaultAudioCompletionRatioMap() map[string]float64 { + return defaultAudioCompletionRatio +} + func GetCompletionRatioMap() map[string]float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() @@ -584,32 +618,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { } func GetAudioRatio(name string) float64 { - if strings.Contains(name, "-realtime") { - if strings.HasSuffix(name, "gpt-4o-realtime-preview") { - return 8 - } else if strings.Contains(name, "gpt-4o-mini-realtime-preview") { - return 10 / 0.6 - } else { - return 20 - } - } - if strings.Contains(name, "-audio") { - if strings.HasPrefix(name, "gpt-4o-audio-preview") { - return 40 / 2.5 - } else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") { - return 10 / 0.15 - } else { - return 40 - } + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + name = FormatMatchingModelName(name) + if ratio, ok := audioRatioMap[name]; ok { + return ratio } return 20 } func GetAudioCompletionRatio(name string) float64 { - if strings.HasPrefix(name, "gpt-4o-realtime") { - return 2 - } else if strings.HasPrefix(name, "gpt-4o-mini-realtime") { - return 2 + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + name = FormatMatchingModelName(name) + if ratio, ok := audioCompletionRatioMap[name]; ok { + + return ratio } return 2 } @@ -630,6 +654,14 @@ var defaultImageRatio = map[string]float64{ } var imageRatioMap map[string]float64 var imageRatioMapMutex sync.RWMutex +var ( + audioRatioMap map[string]float64 = nil + audioRatioMapMutex = sync.RWMutex{} +) +var ( + audioCompletionRatioMap map[string]float64 = nil + audioCompletionRatioMapMutex = sync.RWMutex{} +) func ImageRatio2JSONString() string { imageRatioMapMutex.RLock() @@ -658,6 +690,71 @@ func GetImageRatio(name string) (float64, bool) { return ratio, true } +func AudioRatio2JSONString() string { + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + jsonBytes, err := common.Marshal(audioRatioMap) + if err != nil { + common.SysError("error marshalling audio ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateAudioRatioByJSONString(jsonStr string) error { + + tmp := make(map[string]float64) + if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil { + return err + } + audioRatioMapMutex.Lock() + audioRatioMap = tmp + audioRatioMapMutex.Unlock() + InvalidateExposedDataCache() + return nil +} + +func GetAudioRatioCopy() map[string]float64 { + audioRatioMapMutex.RLock() + defer audioRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(audioRatioMap)) + for k, v := range audioRatioMap { + copyMap[k] = v + } + return copyMap +} + +func AudioCompletionRatio2JSONString() string { + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + jsonBytes, err := common.Marshal(audioCompletionRatioMap) + if err != nil { + common.SysError("error marshalling audio completion ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateAudioCompletionRatioByJSONString(jsonStr string) error { + tmp := make(map[string]float64) + if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil { + return err + } + audioCompletionRatioMapMutex.Lock() + audioCompletionRatioMap = tmp + audioCompletionRatioMapMutex.Unlock() + InvalidateExposedDataCache() + return nil +} + +func GetAudioCompletionRatioCopy() map[string]float64 { + audioCompletionRatioMapMutex.RLock() + defer audioCompletionRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(audioCompletionRatioMap)) + for k, v := range audioCompletionRatioMap { + copyMap[k] = v + } + return copyMap +} + func GetModelRatioCopy() map[string]float64 { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go new file mode 100644 index 00000000..c41b930a --- /dev/null +++ b/setting/system_setting/fetch_setting.go @@ -0,0 +1,34 @@ +package system_setting + +import "one-api/setting/config" + +type FetchSetting struct { + EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护 + AllowPrivateIp bool `json:"allow_private_ip"` + DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式 + IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式 + DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com + IpList []string `json:"ip_list"` // CIDR format + AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000 + ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性) +} + +var defaultFetchSetting = FetchSetting{ + EnableSSRFProtection: true, // 默认开启SSRF防护 + AllowPrivateIp: false, + DomainFilterMode: false, + IpFilterMode: false, + DomainList: []string{}, + IpList: []string{}, + AllowedPorts: []string{"80", "443", "8080", "8443"}, + ApplyIPFilterForDomain: false, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting) +} + +func GetFetchSetting() *FetchSetting { + return &defaultFetchSetting +} diff --git a/types/error.go b/types/error.go index 883ee064..a42e8438 100644 --- a/types/error.go +++ b/types/error.go @@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string { return string(e.errorCode) } errStr := e.Err.Error() + if e.errorCode == ErrorCodeCountTokenFailed { + return errStr + } return common.MaskSensitiveInfo(errStr) } @@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { Code: e.errorCode, } } - - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } @@ -178,7 +182,9 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { Type: string(e.errorType), } } - result.Message = common.MaskSensitiveInfo(result.Message) + if e.errorCode != ErrorCodeCountTokenFailed { + result.Message = common.MaskSensitiveInfo(result.Message) + } return result } diff --git a/types/price_data.go b/types/price_data.go index f6a92d7e..ec7fcdfe 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -15,6 +15,8 @@ type PriceData struct { CacheRatio float64 CacheCreationRatio float64 ImageRatio float64 + AudioRatio float64 + AudioCompletionRatio float64 UsePrice bool ShouldPreConsumedQuota int GroupRatioInfo GroupRatioInfo @@ -27,5 +29,5 @@ type PerCallPriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/jsconfig.json b/web/jsconfig.json new file mode 100644 index 00000000..ced4d054 --- /dev/null +++ b/web/jsconfig.json @@ -0,0 +1,9 @@ +{ + "compilerOptions": { + "baseUrl": "./", + "paths": { + "@/*": ["src/*"] + } + }, + "include": ["src/**/*"] +} \ No newline at end of file diff --git a/web/src/components/layout/headerbar/UserArea.jsx b/web/src/components/layout/headerbar/UserArea.jsx index 8ea70f47..9fc011da 100644 --- a/web/src/components/layout/headerbar/UserArea.jsx +++ b/web/src/components/layout/headerbar/UserArea.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React from 'react'; +import React, { useRef } from 'react'; import { Link } from 'react-router-dom'; import { Avatar, Button, Dropdown, Typography } from '@douyinfe/semi-ui'; import { ChevronDown } from 'lucide-react'; @@ -39,6 +39,7 @@ const UserArea = ({ navigate, t, }) => { + const dropdownRef = useRef(null); if (isLoading) { return ( - { - navigate('/console/personal'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('个人设置')} -
-
- { - navigate('/console/token'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('令牌管理')} -
-
- { - navigate('/console/topup'); - }} - className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white' - > -
- - {t('钱包管理')} -
-
- -
- - {t('退出')} -
-
- - } - > - - + + {userState.user.username[0].toUpperCase()} + + + + {userState.user.username} + + + + + + ); } else { const showRegisterButton = !isSelfUseMode; diff --git a/web/src/components/settings/RatioSetting.jsx b/web/src/components/settings/RatioSetting.jsx index 096722bb..f5d8ef99 100644 --- a/web/src/components/settings/RatioSetting.jsx +++ b/web/src/components/settings/RatioSetting.jsx @@ -39,6 +39,9 @@ const RatioSetting = () => { CompletionRatio: '', GroupRatio: '', GroupGroupRatio: '', + ImageRatio: '', + AudioRatio: '', + AudioCompletionRatio: '', AutoGroups: '', DefaultUseAutoGroup: false, ExposeRatioEnabled: false, @@ -61,7 +64,10 @@ const RatioSetting = () => { item.key === 'UserUsableGroups' || item.key === 'CompletionRatio' || item.key === 'ModelPrice' || - item.key === 'CacheRatio' + item.key === 'CacheRatio' || + item.key === 'ImageRatio' || + item.key === 'AudioRatio' || + item.key === 'AudioCompletionRatio' ) { try { item.value = JSON.stringify(JSON.parse(item.value), null, 2); diff --git a/web/src/components/settings/SystemSetting.jsx b/web/src/components/settings/SystemSetting.jsx index 9c7eeaad..f9a2c019 100644 --- a/web/src/components/settings/SystemSetting.jsx +++ b/web/src/components/settings/SystemSetting.jsx @@ -29,6 +29,7 @@ import { TagInput, Spin, Card, + Radio, } from '@douyinfe/semi-ui'; const { Text } = Typography; import { @@ -44,6 +45,7 @@ import { useTranslation } from 'react-i18next'; const SystemSetting = () => { const { t } = useTranslation(); let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', PasswordRegisterEnabled: '', EmailVerificationEnabled: '', @@ -87,6 +89,15 @@ const SystemSetting = () => { LinuxDOClientSecret: '', LinuxDOMinimumTrustLevel: '', ServerAddress: '', + // SSRF防护配置 + 'fetch_setting.enable_ssrf_protection': true, + 'fetch_setting.allow_private_ip': '', + 'fetch_setting.domain_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.ip_filter_mode': false, // true 白名单,false 黑名单 + 'fetch_setting.domain_list': [], + 'fetch_setting.ip_list': [], + 'fetch_setting.allowed_ports': [], + 'fetch_setting.apply_ip_filter_for_domain': false, }); const [originInputs, setOriginInputs] = useState({}); @@ -98,6 +109,11 @@ const SystemSetting = () => { useState(false); const [linuxDOOAuthEnabled, setLinuxDOOAuthEnabled] = useState(false); const [emailToAdd, setEmailToAdd] = useState(''); + const [domainFilterMode, setDomainFilterMode] = useState(true); + const [ipFilterMode, setIpFilterMode] = useState(true); + const [domainList, setDomainList] = useState([]); + const [ipList, setIpList] = useState([]); + const [allowedPorts, setAllowedPorts] = useState([]); const getOptions = async () => { setLoading(true); @@ -113,6 +129,37 @@ const SystemSetting = () => { case 'EmailDomainWhitelist': setEmailDomainWhitelist(item.value ? item.value.split(',') : []); break; + case 'fetch_setting.allow_private_ip': + case 'fetch_setting.enable_ssrf_protection': + case 'fetch_setting.domain_filter_mode': + case 'fetch_setting.ip_filter_mode': + case 'fetch_setting.apply_ip_filter_for_domain': + item.value = toBoolean(item.value); + break; + case 'fetch_setting.domain_list': + try { + const domains = item.value ? JSON.parse(item.value) : []; + setDomainList(Array.isArray(domains) ? domains : []); + } catch (e) { + setDomainList([]); + } + break; + case 'fetch_setting.ip_list': + try { + const ips = item.value ? JSON.parse(item.value) : []; + setIpList(Array.isArray(ips) ? ips : []); + } catch (e) { + setIpList([]); + } + break; + case 'fetch_setting.allowed_ports': + try { + const ports = item.value ? JSON.parse(item.value) : []; + setAllowedPorts(Array.isArray(ports) ? ports : []); + } catch (e) { + setAllowedPorts(['80', '443', '8080', '8443']); + } + break; case 'PasswordLoginEnabled': case 'PasswordRegisterEnabled': case 'EmailVerificationEnabled': @@ -140,6 +187,13 @@ const SystemSetting = () => { }); setInputs(newInputs); setOriginInputs(newInputs); + // 同步模式布尔到本地状态 + if (typeof newInputs['fetch_setting.domain_filter_mode'] !== 'undefined') { + setDomainFilterMode(!!newInputs['fetch_setting.domain_filter_mode']); + } + if (typeof newInputs['fetch_setting.ip_filter_mode'] !== 'undefined') { + setIpFilterMode(!!newInputs['fetch_setting.ip_filter_mode']); + } if (formApiRef.current) { formApiRef.current.setValues(newInputs); } @@ -276,6 +330,46 @@ const SystemSetting = () => { } }; + const submitSSRF = async () => { + const options = []; + + // 处理域名过滤模式与列表 + options.push({ + key: 'fetch_setting.domain_filter_mode', + value: domainFilterMode, + }); + if (Array.isArray(domainList)) { + options.push({ + key: 'fetch_setting.domain_list', + value: JSON.stringify(domainList), + }); + } + + // 处理IP过滤模式与列表 + options.push({ + key: 'fetch_setting.ip_filter_mode', + value: ipFilterMode, + }); + if (Array.isArray(ipList)) { + options.push({ + key: 'fetch_setting.ip_list', + value: JSON.stringify(ipList), + }); + } + + // 处理端口配置 + if (Array.isArray(allowedPorts)) { + options.push({ + key: 'fetch_setting.allowed_ports', + value: JSON.stringify(allowedPorts), + }); + } + + if (options.length > 0) { + await updateOptions(options); + } + }; + const handleAddEmail = () => { if (emailToAdd && emailToAdd.trim() !== '') { const domain = emailToAdd.trim(); @@ -587,6 +681,179 @@ const SystemSetting = () => { + + + + {t('配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全')} + + + + + handleCheckboxChange('fetch_setting.enable_ssrf_protection', e) + } + > + {t('启用SSRF防护(推荐开启以保护服务器安全)')} + + + + + + + + handleCheckboxChange('fetch_setting.allow_private_ip', e) + } + > + {t('允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)')} + + + + + + + + handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e) + } + style={{ marginBottom: 8 }} + > + {t('对域名启用 IP 过滤(实验性)')} + + + {t(domainFilterMode ? '域名白名单' : '域名黑名单')} + + + {t('支持通配符格式,如:example.com, *.api.example.com')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setDomainFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setDomainList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.domain_list': value + })); + }} + placeholder={t('输入域名后回车,如:example.com')} + style={{ width: '100%' }} + /> + + + + + + + {t(ipFilterMode ? 'IP白名单' : 'IP黑名单')} + + + {t('支持CIDR格式,如:8.8.8.8, 192.168.1.0/24')} + + { + const selected = val && val.target ? val.target.value : val; + const isWhitelist = selected === 'whitelist'; + setIpFilterMode(isWhitelist); + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_filter_mode': isWhitelist, + })); + }} + style={{ marginBottom: 8 }} + > + {t('白名单')} + {t('黑名单')} + + { + setIpList(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.ip_list': value + })); + }} + placeholder={t('输入IP地址后回车,如:8.8.8.8')} + style={{ width: '100%' }} + /> + + + + + + {t('允许的端口')} + + {t('支持单个端口和端口范围,如:80, 443, 8000-8999')} + + { + setAllowedPorts(value); + // 触发Form的onChange事件 + setInputs(prev => ({ + ...prev, + 'fetch_setting.allowed_ports': value + })); + }} + placeholder={t('输入端口后回车,如:80 或 8000-8999')} + style={{ width: '100%' }} + /> + + {t('端口配置详细说明')} + + + + + + + + { return (checked) => { @@ -132,6 +136,9 @@ const NotificationSettings = ({ }); if (res.data.success) { showSuccess(t('侧边栏设置保存成功')); + + // 刷新useSidebar钩子中的用户配置,实现实时更新 + await refreshUserConfig(); } else { showError(res.data.message); } @@ -334,7 +341,7 @@ const NotificationSettings = ({ loading={sidebarLoading} className='!rounded-lg' > - {t('保存边栏设置')} + {t('保存设置')} ) : ( diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 766c1715..b63c7dd4 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -35,8 +35,9 @@ import { Sparkles, } from 'lucide-react'; import { - TASK_ACTION_GENERATE, - TASK_ACTION_TEXT_GENERATE, + TASK_ACTION_FIRST_TAIL_GENERATE, + TASK_ACTION_GENERATE, TASK_ACTION_REFERENCE_GENERATE, + TASK_ACTION_TEXT_GENERATE } from '../../../constants/common.constant'; import { CHANNEL_OPTIONS } from '../../../constants/channel.constants'; @@ -111,6 +112,18 @@ const renderType = (type, t) => { {t('文生视频')} ); + case TASK_ACTION_FIRST_TAIL_GENERATE: + return ( + }> + {t('首尾生视频')} + + ); + case TASK_ACTION_REFERENCE_GENERATE: + return ( + }> + {t('参照生视频')} + + ); default: return ( }> @@ -343,7 +356,9 @@ export const getTaskLogsColumns = ({ // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 const isVideoTask = record.action === TASK_ACTION_GENERATE || - record.action === TASK_ACTION_TEXT_GENERATE; + record.action === TASK_ACTION_TEXT_GENERATE || + record.action === TASK_ACTION_FIRST_TAIL_GENERATE || + record.action === TASK_ACTION_REFERENCE_GENERATE; const isSuccess = record.status === 'SUCCESS'; const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); if (isSuccess && isVideoTask && isUrl) { diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js index 277bb9a5..57fbbbde 100644 --- a/web/src/constants/common.constant.js +++ b/web/src/constants/common.constant.js @@ -40,3 +40,5 @@ export const API_ENDPOINTS = [ export const TASK_ACTION_GENERATE = 'generate'; export const TASK_ACTION_TEXT_GENERATE = 'textGenerate'; +export const TASK_ACTION_FIRST_TAIL_GENERATE = 'firstTailGenerate'; +export const TASK_ACTION_REFERENCE_GENERATE = 'referenceGenerate'; diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index 65332701..c19e2849 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -1027,6 +1027,8 @@ export function renderModelPrice( audioInputSeperatePrice = false, audioInputTokens = 0, audioInputPrice = 0, + imageGenerationCall = false, + imageGenerationCallPrice = 0, ) { const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio( groupRatio, @@ -1069,7 +1071,8 @@ export function renderModelPrice( (audioInputTokens / 1000000) * audioInputPrice * groupRatio + (completionTokens / 1000000) * completionRatioPrice * groupRatio + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + - (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio + + (imageGenerationCallPrice * groupRatio); return ( <> @@ -1131,7 +1134,13 @@ export function renderModelPrice( })}

)} -

+ {imageGenerationCall && imageGenerationCallPrice > 0 && ( +

+ {i18next.t('图片生成调用:${{price}} / 1次', { + price: imageGenerationCallPrice, + })} +

+ )}

{(() => { // 构建输入部分描述 @@ -1211,6 +1220,16 @@ export function renderModelPrice( }, ) : '', + imageGenerationCall && imageGenerationCallPrice > 0 + ? i18next.t( + ' + 图片生成调用 ${{price}} / 1次 * {{ratioType}} {{ratio}}', + { + price: imageGenerationCallPrice, + ratio: groupRatio, + ratioType: ratioLabel, + }, + ) + : '', ].join(''); return i18next.t( diff --git a/web/src/hooks/common/useSidebar.js b/web/src/hooks/common/useSidebar.js index 5dce44f9..13d76fd8 100644 --- a/web/src/hooks/common/useSidebar.js +++ b/web/src/hooks/common/useSidebar.js @@ -21,6 +21,10 @@ import { useState, useEffect, useMemo, useContext } from 'react'; import { StatusContext } from '../../context/Status'; import { API } from '../../helpers'; +// 创建一个全局事件系统来同步所有useSidebar实例 +const sidebarEventTarget = new EventTarget(); +const SIDEBAR_REFRESH_EVENT = 'sidebar-refresh'; + export const useSidebar = () => { const [statusState] = useContext(StatusContext); const [userConfig, setUserConfig] = useState(null); @@ -124,9 +128,12 @@ export const useSidebar = () => { // 刷新用户配置的方法(供外部调用) const refreshUserConfig = async () => { - if (Object.keys(adminConfig).length > 0) { + if (Object.keys(adminConfig).length > 0) { await loadUserConfig(); } + + // 触发全局刷新事件,通知所有useSidebar实例更新 + sidebarEventTarget.dispatchEvent(new CustomEvent(SIDEBAR_REFRESH_EVENT)); }; // 加载用户配置 @@ -137,6 +144,21 @@ export const useSidebar = () => { } }, [adminConfig]); + // 监听全局刷新事件 + useEffect(() => { + const handleRefresh = () => { + if (Object.keys(adminConfig).length > 0) { + loadUserConfig(); + } + }; + + sidebarEventTarget.addEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh); + + return () => { + sidebarEventTarget.removeEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh); + }; + }, [adminConfig]); + // 计算最终的显示配置 const finalConfig = useMemo(() => { const result = {}; diff --git a/web/src/hooks/dashboard/useDashboardStats.jsx b/web/src/hooks/dashboard/useDashboardStats.jsx index aa9677a5..dbf3b67e 100644 --- a/web/src/hooks/dashboard/useDashboardStats.jsx +++ b/web/src/hooks/dashboard/useDashboardStats.jsx @@ -102,7 +102,7 @@ export const useDashboardStats = ( }, { title: t('统计Tokens'), - value: isNaN(consumeTokens) ? 0 : consumeTokens, + value: isNaN(consumeTokens) ? 0 : consumeTokens.toLocaleString(), icon: , avatarColor: 'pink', trendData: trendData.tokens, diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 81f3f539..d434e733 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -447,6 +447,8 @@ export const useLogsData = () => { other?.audio_input_seperate_price || false, other?.audio_input_token_count || 0, other?.audio_input_price || 0, + other?.image_generation_call || false, + other?.image_generation_call_price || 0, ); } expandDataLocal.push({ diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index a527b91c..a305b0a9 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1999,6 +1999,16 @@ "查看渠道密钥": "View channel key", "渠道密钥信息": "Channel key information", "密钥获取成功": "Key acquisition successful", + "模型补全倍率(仅对自定义模型有效)": "Model completion ratio (only effective for custom models)", + "图片倍率": "Image ratio", + "音频倍率": "Audio ratio", + "音频补全倍率": "Audio completion ratio", + "图片输入相关的倍率设置,键为模型名称,值为倍率": "Image input related ratio settings, key is model name, value is ratio", + "音频输入相关的倍率设置,键为模型名称,值为倍率": "Audio input related ratio settings, key is model name, value is ratio", + "音频输出补全相关的倍率设置,键为模型名称,值为倍率": "Audio output completion related ratio settings, key is model name, value is ratio", + "为一个 JSON 文本,键为模型名称,值为倍率,例如:{\"gpt-image-1\": 2}": "A JSON text with model name as key and ratio as value, e.g.: {\"gpt-image-1\": 2}", + "为一个 JSON 文本,键为模型名称,值为倍率,例如:{\"gpt-4o-audio-preview\": 16}": "A JSON text with model name as key and ratio as value, e.g.: {\"gpt-4o-audio-preview\": 16}", + "为一个 JSON 文本,键为模型名称,值为倍率,例如:{\"gpt-4o-realtime\": 2}": "A JSON text with model name as key and ratio as value, e.g.: {\"gpt-4o-realtime\": 2}", "顶栏管理": "Header Management", "控制顶栏模块显示状态,全局生效": "Control header module display status, global effect", "用户主页,展示系统信息": "User homepage, displaying system information", @@ -2058,7 +2068,7 @@ "需要登录访问": "Require Login", "开启后未登录用户无法访问模型广场": "When enabled, unauthenticated users cannot access the model marketplace", "参与官方同步": "Participate in official sync", - "关闭后,此模型将不会被“同步官方”自动覆盖或创建": "When turned off, this model will be skipped by Sync official (no auto create/overwrite)", + "关闭后,此模型将不会被\"同步官方\"自动覆盖或创建": "When turned off, this model will be skipped by Sync official (no auto create/overwrite)", "同步": "Sync", "同步向导": "Sync Wizard", "选择方式": "Select method", @@ -2089,5 +2099,31 @@ "近 7 天": "Last 7 Days", "本周": "This Week", "本月": "This Month", - "近 30 天": "Last 30 Days" + "近 30 天": "Last 30 Days", + "代理设置": "Proxy Settings", + "更新Worker设置": "Update Worker Settings", + "SSRF防护设置": "SSRF Protection Settings", + "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全": "Configure Server-Side Request Forgery (SSRF) protection to secure internal network resources", + "SSRF防护详细说明": "SSRF protection prevents malicious users from using your server to access internal network resources. Configure whitelists for trusted domains/IPs and restrict allowed ports. Applies to file downloads, webhooks, and notifications.", + "启用SSRF防护(推荐开启以保护服务器安全)": "Enable SSRF Protection (Recommended for server security)", + "SSRF防护开关详细说明": "Master switch controls whether SSRF protection is enabled. When disabled, all SSRF checks are bypassed, allowing access to any URL. ⚠️ Only disable this feature in completely trusted environments.", + "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)": "Allow access to private IP addresses (127.0.0.1, 192.168.x.x and other internal addresses)", + "私有IP访问详细说明": "⚠️ Security Warning: Enabling this allows access to internal network resources (localhost, private networks). Only enable if you need to access internal services and understand the security implications.", + "域名白名单": "Domain Whitelist", + "支持通配符格式,如:example.com, *.api.example.com": "Supports wildcard format, e.g.: example.com, *.api.example.com", + "域名白名单详细说明": "Whitelisted domains bypass all SSRF checks and are allowed direct access. Supports exact domains (example.com) or wildcards (*.api.example.com) for subdomains. When whitelist is empty, all domains go through SSRF validation.", + "输入域名后回车,如:example.com": "Enter domain and press Enter, e.g.: example.com", + "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24": "Supports CIDR format, e.g.: 8.8.8.8, 192.168.1.0/24", + "IP白名单详细说明": "Controls which IP addresses are allowed access. Use single IPs (8.8.8.8) or CIDR notation (192.168.1.0/24). Empty whitelist allows all IPs (subject to private IP settings), non-empty whitelist only allows listed IPs.", + "输入IP地址后回车,如:8.8.8.8": "Enter IP address and press Enter, e.g.: 8.8.8.8", + "允许的端口": "Allowed Ports", + "支持单个端口和端口范围,如:80, 443, 8000-8999": "Supports single ports and port ranges, e.g.: 80, 443, 8000-8999", + "端口配置详细说明": "Restrict external requests to specific ports. Use single ports (80, 443) or ranges (8000-8999). Empty list allows all ports. Default includes common web ports.", + "输入端口后回车,如:80 或 8000-8999": "Enter port and press Enter, e.g.: 80 or 8000-8999", + "更新SSRF防护设置": "Update SSRF Protection Settings", + "对域名启用 IP 过滤(实验性)": "Enable IP filtering for domains (experimental)", + "域名IP过滤详细说明": "⚠️ This is an experimental option. A domain may resolve to multiple IPv4/IPv6 addresses. If enabled, ensure the IP filter list covers these addresses, otherwise access may fail.", + "域名黑名单": "Domain Blacklist", + "白名单": "Whitelist", + "黑名单": "Blacklist" } diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index 5c7904fc..95fa0641 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -9,5 +9,28 @@ "语言": "语言", "展开侧边栏": "展开侧边栏", "关闭侧边栏": "关闭侧边栏", - "注销成功!": "注销成功!" + "注销成功!": "注销成功!", + "代理设置": "代理设置", + "更新Worker设置": "更新Worker设置", + "SSRF防护设置": "SSRF防护设置", + "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全": "配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全", + "SSRF防护详细说明": "SSRF防护可防止恶意用户利用您的服务器访问内网资源。您可以配置受信任域名/IP的白名单,并限制允许的端口。适用于文件下载、Webhook回调和通知功能。", + "启用SSRF防护(推荐开启以保护服务器安全)": "启用SSRF防护(推荐开启以保护服务器安全)", + "SSRF防护开关详细说明": "总开关控制是否启用SSRF防护功能。关闭后将跳过所有SSRF检查,允许访问任意URL。⚠️ 仅在完全信任环境中关闭此功能。", + "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)": "允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)", + "私有IP访问详细说明": "⚠️ 安全警告:启用此选项将允许访问内网资源(本地主机、私有网络)。仅在需要访问内部服务且了解安全风险的情况下启用。", + "域名白名单": "域名白名单", + "支持通配符格式,如:example.com, *.api.example.com": "支持通配符格式,如:example.com, *.api.example.com", + "域名白名单详细说明": "白名单中的域名将绕过所有SSRF检查,直接允许访问。支持精确域名(example.com)或通配符(*.api.example.com)匹配子域名。白名单为空时,所有域名都需要通过SSRF检查。", + "输入域名后回车,如:example.com": "输入域名后回车,如:example.com", + "IP白名单": "IP白名单", + "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24": "支持CIDR格式,如:8.8.8.8, 192.168.1.0/24", + "IP白名单详细说明": "控制允许访问的IP地址。支持单个IP(8.8.8.8)或CIDR网段(192.168.1.0/24)。空白名单允许所有IP(但仍受私有IP设置限制),非空白名单仅允许列表中的IP访问。", + "输入IP地址后回车,如:8.8.8.8": "输入IP地址后回车,如:8.8.8.8", + "允许的端口": "允许的端口", + "支持单个端口和端口范围,如:80, 443, 8000-8999": "支持单个端口和端口范围,如:80, 443, 8000-8999", + "端口配置详细说明": "限制外部请求只能访问指定端口。支持单个端口(80, 443)或端口范围(8000-8999)。空列表允许所有端口。默认包含常用Web端口。", + "输入端口后回车,如:80 或 8000-8999": "输入端口后回车,如:80 或 8000-8999", + "更新SSRF防护设置": "更新SSRF防护设置", + "域名IP过滤详细说明": "⚠️此功能为实验性选项,域名可能解析到多个 IPv4/IPv6 地址,若开启,请确保 IP 过滤列表覆盖这些地址,否则可能导致访问失败。" } diff --git a/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx b/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx index 2462a35a..ed982edc 100644 --- a/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx +++ b/web/src/pages/Setting/Ratio/ModelRatioSettings.jsx @@ -44,6 +44,9 @@ export default function ModelRatioSettings(props) { ModelRatio: '', CacheRatio: '', CompletionRatio: '', + ImageRatio: '', + AudioRatio: '', + AudioCompletionRatio: '', ExposeRatioEnabled: false, }); const refForm = useRef(); @@ -219,6 +222,72 @@ export default function ModelRatioSettings(props) { /> + + + verifyJSON(value), + message: '不是合法的 JSON 字符串', + }, + ]} + onChange={(value) => + setInputs({ ...inputs, ImageRatio: value }) + } + /> + + + + + verifyJSON(value), + message: '不是合法的 JSON 字符串', + }, + ]} + onChange={(value) => + setInputs({ ...inputs, AudioRatio: value }) + } + /> + + + + + verifyJSON(value), + message: '不是合法的 JSON 字符串', + }, + ]} + onChange={(value) => + setInputs({ ...inputs, AudioCompletionRatio: value }) + } + /> + +