diff --git a/common/ip.go b/common/ip.go
new file mode 100644
index 00000000..bfb64ee7
--- /dev/null
+++ b/common/ip.go
@@ -0,0 +1,22 @@
+package common
+
+import "net"
+
+func IsPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+
+ private := []net.IPNet{
+ {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
+ {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
+ {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
+ }
+
+ for _, privateNet := range private {
+ if privateNet.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go
new file mode 100644
index 00000000..6f7d289f
--- /dev/null
+++ b/common/ssrf_protection.go
@@ -0,0 +1,327 @@
+package common
+
+import (
+ "fmt"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+// SSRFProtection SSRF防护配置
+type SSRFProtection struct {
+ AllowPrivateIp bool
+ DomainFilterMode bool // true: 白名单, false: 黑名单
+ DomainList []string // domain format, e.g. example.com, *.example.com
+ IpFilterMode bool // true: 白名单, false: 黑名单
+ IpList []string // CIDR or single IP
+ AllowedPorts []int // 允许的端口范围
+ ApplyIPFilterForDomain bool // 对域名启用IP过滤
+}
+
+// DefaultSSRFProtection 默认SSRF防护配置
+var DefaultSSRFProtection = &SSRFProtection{
+ AllowPrivateIp: false,
+ DomainFilterMode: true,
+ DomainList: []string{},
+ IpFilterMode: true,
+ IpList: []string{},
+ AllowedPorts: []int{},
+}
+
+// isPrivateIP 检查IP是否为私有地址
+func isPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+
+ // 检查私有网段
+ private := []net.IPNet{
+ {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
+ {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
+ {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
+ {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
+ {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
+ {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
+ {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
+ }
+
+ for _, privateNet := range private {
+ if privateNet.Contains(ip) {
+ return true
+ }
+ }
+
+ // 检查IPv6私有地址
+ if ip.To4() == nil {
+ // IPv6 loopback
+ if ip.Equal(net.IPv6loopback) {
+ return true
+ }
+ // IPv6 link-local
+ if strings.HasPrefix(ip.String(), "fe80:") {
+ return true
+ }
+ // IPv6 unique local
+ if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
+ return true
+ }
+ }
+
+ return false
+}
+
+// parsePortRanges 解析端口范围配置
+// 支持格式: "80", "443", "8000-9000"
+func parsePortRanges(portConfigs []string) ([]int, error) {
+ var ports []int
+
+ for _, config := range portConfigs {
+ config = strings.TrimSpace(config)
+ if config == "" {
+ continue
+ }
+
+ if strings.Contains(config, "-") {
+ // 处理端口范围 "8000-9000"
+ parts := strings.Split(config, "-")
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid port range format: %s", config)
+ }
+
+ startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+ if err != nil {
+ return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
+ }
+
+ endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
+ if err != nil {
+ return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
+ }
+
+ if startPort > endPort {
+ return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
+ }
+
+ if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
+ return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
+ }
+
+ // 添加范围内的所有端口
+ for port := startPort; port <= endPort; port++ {
+ ports = append(ports, port)
+ }
+ } else {
+ // 处理单个端口 "80"
+ port, err := strconv.Atoi(config)
+ if err != nil {
+ return nil, fmt.Errorf("invalid port number: %s", config)
+ }
+
+ if port < 1 || port > 65535 {
+ return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
+ }
+
+ ports = append(ports, port)
+ }
+ }
+
+ return ports, nil
+}
+
+// isAllowedPort 检查端口是否被允许
+func (p *SSRFProtection) isAllowedPort(port int) bool {
+ if len(p.AllowedPorts) == 0 {
+ return true // 如果没有配置端口限制,则允许所有端口
+ }
+
+ for _, allowedPort := range p.AllowedPorts {
+ if port == allowedPort {
+ return true
+ }
+ }
+ return false
+}
+
+// isDomainWhitelisted 检查域名是否在白名单中
+func isDomainListed(domain string, list []string) bool {
+ if len(list) == 0 {
+ return false
+ }
+
+ domain = strings.ToLower(domain)
+ for _, item := range list {
+ item = strings.ToLower(strings.TrimSpace(item))
+ if item == "" {
+ continue
+ }
+ // 精确匹配
+ if domain == item {
+ return true
+ }
+ // 通配符匹配 (*.example.com)
+ if strings.HasPrefix(item, "*.") {
+ suffix := strings.TrimPrefix(item, "*.")
+ if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func (p *SSRFProtection) isDomainAllowed(domain string) bool {
+ listed := isDomainListed(domain, p.DomainList)
+ if p.DomainFilterMode { // 白名单
+ return listed
+ }
+ // 黑名单
+ return !listed
+}
+
+// isIPWhitelisted 检查IP是否在白名单中
+
+func isIPListed(ip net.IP, list []string) bool {
+ if len(list) == 0 {
+ return false
+ }
+
+ for _, whitelistCIDR := range list {
+ _, network, err := net.ParseCIDR(whitelistCIDR)
+ if err != nil {
+ // 尝试作为单个IP处理
+ if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
+ if ip.Equal(whitelistIP) {
+ return true
+ }
+ }
+ continue
+ }
+
+ if network.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// IsIPAccessAllowed 检查IP是否允许访问
+func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
+ // 私有IP限制
+ if isPrivateIP(ip) && !p.AllowPrivateIp {
+ return false
+ }
+
+ listed := isIPListed(ip, p.IpList)
+ if p.IpFilterMode { // 白名单
+ return listed
+ }
+ // 黑名单
+ return !listed
+}
+
+// ValidateURL 验证URL是否安全
+func (p *SSRFProtection) ValidateURL(urlStr string) error {
+ // 解析URL
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return fmt.Errorf("invalid URL format: %v", err)
+ }
+
+ // 只允许HTTP/HTTPS协议
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
+ }
+
+ // 解析主机和端口
+ host, portStr, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ // 没有端口,使用默认端口
+ host = u.Hostname()
+ if u.Scheme == "https" {
+ portStr = "443"
+ } else {
+ portStr = "80"
+ }
+ }
+
+ // 验证端口
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("invalid port: %s", portStr)
+ }
+
+ if !p.isAllowedPort(port) {
+ return fmt.Errorf("port %d is not allowed", port)
+ }
+
+ // 如果 host 是 IP,则跳过域名检查
+ if ip := net.ParseIP(host); ip != nil {
+ if !p.IsIPAccessAllowed(ip) {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("private IP address not allowed: %s", ip.String())
+ }
+ if p.IpFilterMode {
+ return fmt.Errorf("ip not in whitelist: %s", ip.String())
+ }
+ return fmt.Errorf("ip in blacklist: %s", ip.String())
+ }
+ return nil
+ }
+
+ // 先进行域名过滤
+ if !p.isDomainAllowed(host) {
+ if p.DomainFilterMode {
+ return fmt.Errorf("domain not in whitelist: %s", host)
+ }
+ return fmt.Errorf("domain in blacklist: %s", host)
+ }
+
+ // 若未启用对域名应用IP过滤,则到此通过
+ if !p.ApplyIPFilterForDomain {
+ return nil
+ }
+
+ // 解析域名对应IP并检查
+ ips, err := net.LookupIP(host)
+ if err != nil {
+ return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
+ }
+ for _, ip := range ips {
+ if !p.IsIPAccessAllowed(ip) {
+ if isPrivateIP(ip) && !p.AllowPrivateIp {
+ return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
+ }
+ if p.IpFilterMode {
+ return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
+ }
+ return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
+ }
+ }
+ return nil
+}
+
+// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
+func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
+ // 如果SSRF防护被禁用,直接返回成功
+ if !enableSSRFProtection {
+ return nil
+ }
+
+ // 解析端口范围配置
+ allowedPortInts, err := parsePortRanges(allowedPorts)
+ if err != nil {
+ return fmt.Errorf("request reject - invalid port configuration: %v", err)
+ }
+
+ protection := &SSRFProtection{
+ AllowPrivateIp: allowPrivateIp,
+ DomainFilterMode: domainFilterMode,
+ DomainList: domainList,
+ IpFilterMode: ipFilterMode,
+ IpList: ipList,
+ AllowedPorts: allowedPortInts,
+ ApplyIPFilterForDomain: applyIPFilterForDomain,
+ }
+ return protection.ValidateURL(urlStr)
+}
diff --git a/constant/task.go b/constant/task.go
index 21790145..e174fd60 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -11,8 +11,10 @@ const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
- TaskActionGenerate = "generate"
- TaskActionTextGenerate = "textGenerate"
+ TaskActionGenerate = "generate"
+ TaskActionTextGenerate = "textGenerate"
+ TaskActionFirstTailGenerate = "firstTailGenerate"
+ TaskActionReferenceGenerate = "referenceGenerate"
)
var SunoModel2Action = map[string]string{
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 5a668c48..9ea6eed7 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult {
requestPath = "/v1/embeddings" // 修改请求路径
}
+ // VolcEngine 图像生成模型
+ if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+ requestPath = "/v1/images/generations"
+ }
+
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: requestPath}, // 使用动态路径
@@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
}
+ // 重新检查模型类型并更新请求路径
+ if strings.Contains(strings.ToLower(testModel), "embedding") ||
+ strings.HasPrefix(testModel, "m3e") ||
+ strings.Contains(testModel, "bge-") ||
+ strings.Contains(testModel, "embed") ||
+ channel.Type == constant.ChannelTypeMokaAI {
+ requestPath = "/v1/embeddings"
+ c.Request.URL.Path = requestPath
+ }
+
+ if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+ requestPath = "/v1/images/generations"
+ c.Request.URL.Path = requestPath
+ }
+
cache, err := model.GetUserCache(1)
if err != nil {
return testResult{
@@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
if c.Request.URL.Path == "/v1/embeddings" {
relayFormat = types.RelayFormatEmbedding
}
+ if c.Request.URL.Path == "/v1/images/generations" {
+ relayFormat = types.RelayFormatOpenAIImage
+ }
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
// 调用专门用于 Embedding 的转换函数
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+ } else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+ // 创建一个 ImageRequest
+ prompt := "cat"
+ if request.Prompt != nil {
+ if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
+ prompt = promptStr
+ }
+ }
+ imageRequest := dto.ImageRequest{
+ Prompt: prompt,
+ Model: request.Model,
+ N: uint(request.N),
+ Size: request.Size,
+ }
+ // 调用专门用于图像生成的转换函数
+ convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
} else {
// 对其他所有请求类型(如 Chat),保持原有逻辑
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
diff --git a/controller/channel.go b/controller/channel.go
index 403eb04c..17154ab0 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -501,9 +501,10 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
}
type AddChannelRequest struct {
- Mode string `json:"mode"`
- MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
- Channel *model.Channel `json:"channel"`
+ Mode string `json:"mode"`
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+ BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
+ Channel *model.Channel `json:"channel"`
}
func getVertexArrayKeys(keys string) ([]string, error) {
@@ -616,6 +617,13 @@ func AddChannel(c *gin.Context) {
}
localChannel := addChannelRequest.Channel
localChannel.Key = key
+ if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
+ keyPrefix := localChannel.Key
+ if len(localChannel.Key) > 8 {
+ keyPrefix = localChannel.Key[:8]
+ }
+ localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
+ }
channels = append(channels, *localChannel)
}
err = model.BatchInsertChannels(channels)
diff --git a/controller/option.go b/controller/option.go
index e5f2b75b..7d1c676f 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -128,6 +128,33 @@ func UpdateOption(c *gin.Context) {
})
return
}
+ case "ImageRatio":
+ err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "图片倍率设置失败: " + err.Error(),
+ })
+ return
+ }
+ case "AudioRatio":
+ err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "音频倍率设置失败: " + err.Error(),
+ })
+ return
+ }
+ case "AudioCompletionRatio":
+ err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "音频补全倍率设置失败: " + err.Error(),
+ })
+ return
+ }
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
if err != nil {
diff --git a/controller/setup.go b/controller/setup.go
index 44a7b3a7..3ae255e9 100644
--- a/controller/setup.go
+++ b/controller/setup.go
@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "请求参数有误",
})
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
if !rootExists {
// Validate username length: max 12 characters to align with model.User validation
if len(req.Username) > 12 {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "用户名长度不能超过12个字符",
})
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
}
// Validate password
if req.Password != req.ConfirmPassword {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
}
if len(req.Password) < 8 {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&rootUser).Error
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&setup).Error
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})
diff --git a/dto/openai_request.go b/dto/openai_request.go
index cd05a63c..191fa638 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -772,11 +772,12 @@ type OpenAIResponsesRequest struct {
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
- ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
+ ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
- Store bool `json:"store,omitempty"`
+ Store json.RawMessage `json:"store,omitempty"`
+ PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"`
diff --git a/dto/openai_response.go b/dto/openai_response.go
index 966748cb..6353c15f 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -6,6 +6,10 @@ import (
"one-api/types"
)
+const (
+ ResponsesOutputTypeImageGenerationCall = "image_generation_call"
+)
+
type SimpleResponse struct {
Usage `json:"usage"`
Error any `json:"error"`
@@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
+func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
+ if len(o.Output) == 0 {
+ return false
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return true
+ }
+ }
+ return false
+}
+
+func (o *OpenAIResponsesResponse) GetQuality() string {
+ if len(o.Output) == 0 {
+ return ""
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return output.Quality
+ }
+ }
+ return ""
+}
+
+func (o *OpenAIResponsesResponse) GetSize() string {
+ if len(o.Output) == 0 {
+ return ""
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return output.Size
+ }
+ }
+ return ""
+}
+
type IncompleteDetails struct {
Reasoning string `json:"reasoning"`
}
@@ -283,6 +323,8 @@ type ResponsesOutput struct {
Status string `json:"status"`
Role string `json:"role"`
Content []ResponsesOutputContent `json:"content"`
+ Quality string `json:"quality"`
+ Size string `json:"size"`
}
type ResponsesOutputContent struct {
diff --git a/model/option.go b/model/option.go
index fefee4e7..ceecff65 100644
--- a/model/option.go
+++ b/model/option.go
@@ -112,6 +112,9 @@ func InitOptionMap() {
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
+ common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
+ common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
+ common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -397,6 +400,12 @@ func updateOptionMap(key string, value string) (err error) {
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = ratio_setting.UpdateCacheRatioByJSONString(value)
+ case "ImageRatio":
+ err = ratio_setting.UpdateImageRatioByJSONString(value)
+ case "AudioRatio":
+ err = ratio_setting.UpdateAudioRatioByJSONString(value)
+ case "AudioCompletionRatio":
+ err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
index 17d732ab..962f8794 100644
--- a/relay/channel/deepseek/adaptor.go
+++ b/relay/channel/deepseek/adaptor.go
@@ -3,17 +3,17 @@ package deepseek
import (
"errors"
"fmt"
+ "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
-
- "github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
+ adaptor := claude.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
@@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.ChannelBaseUrl
- if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
- fimBaseUrl += "/beta"
- }
- switch info.RelayMode {
- case constant.RelayModeCompletions:
- return fmt.Sprintf("%s/completions", fimBaseUrl), nil
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
+ fimBaseUrl += "/beta"
+ }
+ switch info.RelayMode {
+ case constant.RelayModeCompletions:
+ return fmt.Sprintf("%s/completions", fimBaseUrl), nil
+ default:
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ }
}
}
@@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ if info.IsStream {
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ } else {
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+ }
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
}
- return
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 4968f78f..57542aa5 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
- if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
- strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+ if strings.Contains(info.RequestURLPath, ":embedContent") ||
+ strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
return NativeGeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream {
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index eb4afbae..199c8466 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -23,6 +23,7 @@ import (
"github.com/gin-gonic/gin"
)
+// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
var geminiSupportedMimeTypes = map[string]bool{
"application/pdf": true,
"audio/mpeg": true,
@@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
+ "image/webp": true,
"text/plain": true,
"video/mov": true,
"video/mpeg": true,
diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go
index e290c239..f24976bb 100644
--- a/relay/channel/moonshot/adaptor.go
+++ b/relay/channel/moonshot/adaptor.go
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
+ adaptor := claude.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index e188889e..85938a77 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
+ if responsesResponse.HasImageGenerationCall() {
+ c.Set("image_generation_call", true)
+ c.Set("image_generation_call_quality", responsesResponse.GetQuality())
+ c.Set("image_generation_call_size", responsesResponse.GetSize())
+ }
+
// 写入新的 response body
service.IOCopyBytesGracefully(c, resp, responseBody)
@@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
- if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
- if streamResponse.Response.Usage.InputTokens != 0 {
- usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ if streamResponse.Response != nil {
+ if streamResponse.Response.Usage != nil {
+ if streamResponse.Response.Usage.InputTokens != 0 {
+ usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ }
+ if streamResponse.Response.Usage.OutputTokens != 0 {
+ usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+ }
+ if streamResponse.Response.Usage.TotalTokens != 0 {
+ usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+ }
+ if streamResponse.Response.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+ }
}
- if streamResponse.Response.Usage.OutputTokens != 0 {
- usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
- }
- if streamResponse.Response.Usage.TotalTokens != 0 {
- usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
- }
- if streamResponse.Response.Usage.InputTokensDetails != nil {
- usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+ if streamResponse.Response.HasImageGenerationCall() {
+ c.Set("image_generation_call", true)
+ c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
+ c.Set("image_generation_call_size", streamResponse.Response.GetSize())
}
}
case "response.output_text.delta":
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
index b954d7b8..a2545a27 100644
--- a/relay/channel/task/jimeng/adaptor.go
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -94,6 +94,9 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if isNewAPIRelay(info.ApiKey) {
+ return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
+ }
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
@@ -101,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
- return a.signRequest(req, a.accessKey, a.secretKey)
+ if isNewAPIRelay(info.ApiKey) {
+ req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+ } else {
+ return a.signRequest(req, a.accessKey, a.secretKey)
+ }
+ return nil
}
// BuildRequestBody converts request into Jimeng specific format.
@@ -161,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
+ if isNewAPIRelay(key) {
+ uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
+ }
payload := map[string]string{
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
@@ -178,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
- keyParts := strings.Split(key, "|")
- if len(keyParts) != 2 {
- return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- secretKey := strings.TrimSpace(keyParts[1])
+ if isNewAPIRelay(key) {
+ req.Header.Set("Authorization", "Bearer "+key)
+ } else {
+ keyParts := strings.Split(key, "|")
+ if len(keyParts) != 2 {
+ return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
- if err := a.signRequest(req, accessKey, secretKey); err != nil {
- return nil, errors.Wrap(err, "sign request failed")
+ if err := a.signRequest(req, accessKey, secretKey); err != nil {
+ return nil, errors.Wrap(err, "sign request failed")
+ }
}
-
return service.GetHttpClient().Do(req)
}
@@ -384,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
taskResult.Url = resTask.Data.VideoUrl
return &taskResult, nil
}
+
+func isNewAPIRelay(apiKey string) bool {
+ return strings.HasPrefix(apiKey, "sk-")
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 13f2af97..fec3396a 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -117,6 +117,11 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+
+ if isNewAPIRelay(info.ApiKey) {
+ return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
+ }
+
return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
@@ -199,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
+ if isNewAPIRelay(key) {
+ url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
+ }
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -304,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
//}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
-
+ if isNewAPIRelay(apiKey) {
+ return apiKey, nil // new api relay
+ }
keyParts := strings.Split(apiKey, "|")
+ if len(keyParts) != 2 {
+ return "", errors.New("invalid api_key, required format is accessKey|secretKey")
+ }
accessKey := strings.TrimSpace(keyParts[0])
if len(keyParts) == 1 {
return accessKey, nil
@@ -352,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
return taskInfo, nil
}
+
+func isNewAPIRelay(apiKey string) bool {
+ return strings.HasPrefix(apiKey, "sk-")
+}
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
index a1140d1e..358aef58 100644
--- a/relay/channel/task/vidu/adaptor.go
+++ b/relay/channel/task/vidu/adaptor.go
@@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
- // Use the unified validation method for TaskSubmitReq with image-based action determination
- return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
switch info.Action {
case constant.TaskActionGenerate:
path = "/img2video"
+ case constant.TaskActionFirstTailGenerate:
+ path = "/start-end2video"
+ case constant.TaskActionReferenceGenerate:
+ path = "/reference2video"
default:
path = "/text2video"
}
@@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string {
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- var images []string
- if req.Image != "" {
- images = []string{req.Image}
- }
-
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
- Images: images,
+ Images: req.Images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
index 0af019da..eb88412a 100644
--- a/relay/channel/volcengine/adaptor.go
+++ b/relay/channel/volcengine/adaptor.go
@@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
+ case constant.RelayModeImagesGenerations:
+ return request, nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go
index 30cc902e..fca10e7c 100644
--- a/relay/channel/volcengine/constants.go
+++ b/relay/channel/volcengine/constants.go
@@ -8,6 +8,7 @@ var ModelList = []string{
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
+ "doubao-seedream-4-0-250828",
}
var ChannelName = "volcengine"
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index dbdc6ee1..59d12abe 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
info.UpstreamModelName = request.Model
}
+ if info.ChannelSetting.SystemPrompt != "" {
+ if request.System == nil {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt)
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ if request.IsStringSystem() {
+ existing := strings.TrimSpace(request.GetStringSystem())
+ if existing == "" {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt)
+ } else {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
+ }
+ } else {
+ systemContents := request.ParseSystem()
+ newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
+ newSystem.SetText(info.ChannelSetting.SystemPrompt)
+ if len(systemContents) == 0 {
+ request.System = []dto.ClaudeMediaMessage{newSystem}
+ } else {
+ request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
+ }
+ }
+ }
+ }
+
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index cf6d08dd..3a721b47 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
req.Images = []string{req.Image}
}
+ if req.HasImage() {
+ action = constant.TaskActionGenerate
+ if info.ChannelType == constant.ChannelTypeVidu {
+ // vidu 增加 首尾帧生视频和参考图生视频
+ if len(req.Images) == 2 {
+ action = constant.TaskActionFirstTailGenerate
+ } else if len(req.Images) > 2 {
+ action = constant.TaskActionReferenceGenerate
+ }
+ }
+ }
+
storeTaskRequest(c, info, action, req)
return nil
}
-
-func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
- hasPrompt, ok := requestObj.(HasPrompt)
- if !ok {
- return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
- }
-
- if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
- return taskErr
- }
-
- action := constant.TaskActionTextGenerate
- if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
- action = constant.TaskActionGenerate
- }
-
- storeTaskRequest(c, info, action, requestObj)
- return nil
-}
-
-func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
- var req TaskSubmitReq
- if err := c.ShouldBindJSON(&req); err != nil {
- return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
- }
-
- return ValidateTaskRequestWithImage(c, info, req)
-}
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go
index 01ab1fff..38b820f7 100644
--- a/relay/compatible_handler.go
+++ b/relay/compatible_handler.go
@@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
- request := convertedRequest.(*dto.GeneralOpenAIRequest)
- containSystemPrompt := false
- for _, message := range request.Messages {
- if message.Role == request.GetSystemRoleName() {
- containSystemPrompt = true
- break
- }
- }
- if !containSystemPrompt {
- // 如果没有系统提示,则添加系统提示
- systemMessage := dto.Message{
- Role: request.GetSystemRoleName(),
- Content: info.ChannelSetting.SystemPrompt,
- }
- request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
- } else if info.ChannelSetting.SystemPromptOverride {
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- // 如果有系统提示,且允许覆盖,则拼接到前面
- for i, message := range request.Messages {
+ request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
+ if ok {
+ containSystemPrompt := false
+ for _, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
- if message.IsStringContent() {
- request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
- } else {
- contents := message.ParseContent()
- contents = append([]dto.MediaContent{
- {
- Type: dto.ContentTypeText,
- Text: info.ChannelSetting.SystemPrompt,
- },
- }, contents...)
- request.Messages[i].Content = contents
- }
+ containSystemPrompt = true
break
}
}
+ if !containSystemPrompt {
+ // 如果没有系统提示,则添加系统提示
+ systemMessage := dto.Message{
+ Role: request.GetSystemRoleName(),
+ Content: info.ChannelSetting.SystemPrompt,
+ }
+ request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ // 如果有系统提示,且允许覆盖,则拼接到前面
+ for i, message := range request.Messages {
+ if message.Role == request.GetSystemRoleName() {
+ if message.IsStringContent() {
+ request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
+ } else {
+ contents := message.ParseContent()
+ contents = append([]dto.MediaContent{
+ {
+ Type: dto.ContentTypeText,
+ Text: info.ChannelSetting.SystemPrompt,
+ },
+ }, contents...)
+ request.Messages[i].Content = contents
+ }
+ break
+ }
+ }
+ }
}
}
@@ -276,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
+ var dImageGenerationCallQuota decimal.Decimal
+ var imageGenerationCallPrice float64
+ if ctx.GetBool("image_generation_call") {
+ imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
+ dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
+ }
var quotaCalculateDecimal decimal.Decimal
@@ -331,6 +340,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
// 添加 audio input 独立计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
+ // 添加 image generation call 计费
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
@@ -429,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
+ if !dImageGenerationCallQuota.IsZero() {
+ other["image_generation_call"] = true
+ other["image_generation_call_price"] = imageGenerationCallPrice
+ }
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index 0252d657..1410da60 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/gemini"
@@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
adaptor.Init(info)
+ if info.ChannelSetting.SystemPrompt != "" {
+ if request.SystemInstructions == nil {
+ request.SystemInstructions = &dto.GeminiChatContent{
+ Parts: []dto.GeminiPart{
+ {Text: info.ChannelSetting.SystemPrompt},
+ },
+ }
+ } else if len(request.SystemInstructions.Parts) == 0 {
+ request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ merged := false
+ for i := range request.SystemInstructions.Parts {
+ if request.SystemInstructions.Parts[i].Text == "" {
+ continue
+ }
+ request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
+ merged = true
+ break
+ }
+ if !merged {
+ request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
+ }
+ }
+ }
+
// Clean up empty system instruction
if request.SystemInstructions != nil {
hasContent := false
diff --git a/relay/helper/price.go b/relay/helper/price.go
index fdc5b66d..c23c068b 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheRatio float64
var imageRatio float64
var cacheCreationRatio float64
+ var audioRatio float64
+ var audioCompletionRatio float64
if !usePrice {
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
if meta.MaxTokens != 0 {
@@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
+ audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
+ audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
@@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
+ AudioRatio: audioRatio,
+ AudioCompletionRatio: audioCompletionRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}
diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go
index 4d1c1f9b..f4a290ec 100644
--- a/relay/helper/valid_request.go
+++ b/relay/helper/valid_request.go
@@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
- request, err = GetAndValidateGeminiRequest(c)
+ if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
+ request, err = GetAndValidateGeminiEmbeddingRequest(c)
+ } else {
+ request, err = GetAndValidateGeminiRequest(c)
+ }
case types.RelayFormatClaude:
request, err = GetAndValidateClaudeRequest(c)
case types.RelayFormatOpenAIResponses:
@@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
}
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
-
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
@@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
return request, nil
}
+
+func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
+ request := &dto.GeminiEmbeddingRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ return request, nil
+}
diff --git a/service/cf_worker.go b/service/download.go
similarity index 59%
rename from service/cf_worker.go
rename to service/download.go
index d60b6fad..036c43af 100644
--- a/service/cf_worker.go
+++ b/service/download.go
@@ -28,6 +28,12 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
return nil, fmt.Errorf("only support https url")
}
+ // SSRF防护:验证请求URL
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return nil, fmt.Errorf("request reject: %v", err)
+ }
+
workerUrl := system_setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
@@ -51,7 +57,13 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
}
return DoWorkerRequest(req)
} else {
- common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
+ // SSRF防护:验证请求URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return nil, fmt.Errorf("request reject: %v", err)
+ }
+
+ common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
return http.Get(originUrl)
}
}
diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go
index 3cfabc1a..0cf53513 100644
--- a/service/pre_consume_quota.go
+++ b/service/pre_consume_quota.go
@@ -19,7 +19,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
gopool.Go(func() {
relayInfoCopy := *relayInfo
- err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
+ err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error())
}
diff --git a/service/user_notify.go b/service/user_notify.go
index 972ca655..fba12d9d 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -113,6 +113,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
}
} else {
+ // SSRF防护:验证Bark URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return fmt.Errorf("request reject: %v", err)
+ }
+
// 直接发送请求
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
if err != nil {
diff --git a/service/webhook.go b/service/webhook.go
index 9c6ec810..c678b863 100644
--- a/service/webhook.go
+++ b/service/webhook.go
@@ -8,6 +8,7 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "one-api/common"
"one-api/dto"
"one-api/setting/system_setting"
"time"
@@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
+ // SSRF防护:验证Webhook URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return fmt.Errorf("request reject: %v", err)
+ }
+
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)
diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go
index 549a1862..5b89d6fe 100644
--- a/setting/operation_setting/tools.go
+++ b/setting/operation_setting/tools.go
@@ -10,6 +10,18 @@ const (
FileSearchPrice = 2.5
)
+const (
+ GPTImage1Low1024x1024 = 0.011
+ GPTImage1Low1024x1536 = 0.016
+ GPTImage1Low1536x1024 = 0.016
+ GPTImage1Medium1024x1024 = 0.042
+ GPTImage1Medium1024x1536 = 0.063
+ GPTImage1Medium1536x1024 = 0.063
+ GPTImage1High1024x1024 = 0.167
+ GPTImage1High1024x1536 = 0.25
+ GPTImage1High1536x1024 = 0.25
+)
+
const (
// Gemini Audio Input Price
Gemini25FlashPreviewInputAudioPrice = 1.00
@@ -65,3 +77,31 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
}
return 0
}
+
+func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
+ prices := map[string]map[string]float64{
+ "low": {
+ "1024x1024": GPTImage1Low1024x1024,
+ "1024x1536": GPTImage1Low1024x1536,
+ "1536x1024": GPTImage1Low1536x1024,
+ },
+ "medium": {
+ "1024x1024": GPTImage1Medium1024x1024,
+ "1024x1536": GPTImage1Medium1024x1536,
+ "1536x1024": GPTImage1Medium1536x1024,
+ },
+ "high": {
+ "1024x1024": GPTImage1High1024x1024,
+ "1024x1536": GPTImage1High1024x1536,
+ "1536x1024": GPTImage1High1536x1024,
+ },
+ }
+
+ if qualityMap, exists := prices[quality]; exists {
+ if price, exists := qualityMap[size]; exists {
+ return price
+ }
+ }
+
+ return GPTImage1High1024x1024
+}
diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go
index f06cd71e..362c6fa1 100644
--- a/setting/ratio_setting/model_ratio.go
+++ b/setting/ratio_setting/model_ratio.go
@@ -178,6 +178,7 @@ var defaultModelRatio = map[string]float64{
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
"gemini-2.5-flash-lite-preview-06-17": 0.05,
"gemini-2.5-flash": 0.15,
+ "gemini-embedding-001": 0.075,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
@@ -278,6 +279,18 @@ var defaultModelPrice = map[string]float64{
"mj_upload": 0.05,
}
+var defaultAudioRatio = map[string]float64{
+ "gpt-4o-audio-preview": 16,
+ "gpt-4o-mini-audio-preview": 66.67,
+ "gpt-4o-realtime-preview": 8,
+ "gpt-4o-mini-realtime-preview": 16.67,
+}
+
+var defaultAudioCompletionRatio = map[string]float64{
+ "gpt-4o-realtime": 2,
+ "gpt-4o-mini-realtime": 2,
+}
+
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
@@ -326,6 +339,15 @@ func InitRatioSettings() {
imageRatioMap = defaultImageRatio
imageRatioMapMutex.Unlock()
+ // initialize audioRatioMap
+ audioRatioMapMutex.Lock()
+ audioRatioMap = defaultAudioRatio
+ audioRatioMapMutex.Unlock()
+
+ // initialize audioCompletionRatioMap
+ audioCompletionRatioMapMutex.Lock()
+ audioCompletionRatioMap = defaultAudioCompletionRatio
+ audioCompletionRatioMapMutex.Unlock()
}
func GetModelPriceMap() map[string]float64 {
@@ -417,6 +439,18 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
+func GetDefaultImageRatioMap() map[string]float64 {
+ return defaultImageRatio
+}
+
+func GetDefaultAudioRatioMap() map[string]float64 {
+ return defaultAudioRatio
+}
+
+func GetDefaultAudioCompletionRatioMap() map[string]float64 {
+ return defaultAudioCompletionRatio
+}
+
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
@@ -584,32 +618,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
func GetAudioRatio(name string) float64 {
- if strings.Contains(name, "-realtime") {
- if strings.HasSuffix(name, "gpt-4o-realtime-preview") {
- return 8
- } else if strings.Contains(name, "gpt-4o-mini-realtime-preview") {
- return 10 / 0.6
- } else {
- return 20
- }
- }
- if strings.Contains(name, "-audio") {
- if strings.HasPrefix(name, "gpt-4o-audio-preview") {
- return 40 / 2.5
- } else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") {
- return 10 / 0.15
- } else {
- return 40
- }
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ name = FormatMatchingModelName(name)
+ if ratio, ok := audioRatioMap[name]; ok {
+ return ratio
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
- if strings.HasPrefix(name, "gpt-4o-realtime") {
- return 2
- } else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
- return 2
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ name = FormatMatchingModelName(name)
+ if ratio, ok := audioCompletionRatioMap[name]; ok {
+
+ return ratio
}
return 2
}
@@ -630,6 +654,14 @@ var defaultImageRatio = map[string]float64{
}
var imageRatioMap map[string]float64
var imageRatioMapMutex sync.RWMutex
+var (
+ audioRatioMap map[string]float64 = nil
+ audioRatioMapMutex = sync.RWMutex{}
+)
+var (
+ audioCompletionRatioMap map[string]float64 = nil
+ audioCompletionRatioMapMutex = sync.RWMutex{}
+)
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
@@ -658,6 +690,71 @@ func GetImageRatio(name string) (float64, bool) {
return ratio, true
}
+func AudioRatio2JSONString() string {
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ jsonBytes, err := common.Marshal(audioRatioMap)
+ if err != nil {
+ common.SysError("error marshalling audio ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateAudioRatioByJSONString(jsonStr string) error {
+
+ tmp := make(map[string]float64)
+ if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
+ return err
+ }
+ audioRatioMapMutex.Lock()
+ audioRatioMap = tmp
+ audioRatioMapMutex.Unlock()
+ InvalidateExposedDataCache()
+ return nil
+}
+
+func GetAudioRatioCopy() map[string]float64 {
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ copyMap := make(map[string]float64, len(audioRatioMap))
+ for k, v := range audioRatioMap {
+ copyMap[k] = v
+ }
+ return copyMap
+}
+
+func AudioCompletionRatio2JSONString() string {
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ jsonBytes, err := common.Marshal(audioCompletionRatioMap)
+ if err != nil {
+ common.SysError("error marshalling audio completion ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
+ tmp := make(map[string]float64)
+ if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
+ return err
+ }
+ audioCompletionRatioMapMutex.Lock()
+ audioCompletionRatioMap = tmp
+ audioCompletionRatioMapMutex.Unlock()
+ InvalidateExposedDataCache()
+ return nil
+}
+
+func GetAudioCompletionRatioCopy() map[string]float64 {
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ copyMap := make(map[string]float64, len(audioCompletionRatioMap))
+ for k, v := range audioCompletionRatioMap {
+ copyMap[k] = v
+ }
+ return copyMap
+}
+
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go
new file mode 100644
index 00000000..c41b930a
--- /dev/null
+++ b/setting/system_setting/fetch_setting.go
@@ -0,0 +1,34 @@
+package system_setting
+
+import "one-api/setting/config"
+
+type FetchSetting struct {
+ EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护
+ AllowPrivateIp bool `json:"allow_private_ip"`
+ DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式
+ IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式
+ DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com
+ IpList []string `json:"ip_list"` // CIDR format
+ AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000
+ ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性)
+}
+
+var defaultFetchSetting = FetchSetting{
+ EnableSSRFProtection: true, // 默认开启SSRF防护
+ AllowPrivateIp: false,
+ DomainFilterMode: false,
+ IpFilterMode: false,
+ DomainList: []string{},
+ IpList: []string{},
+ AllowedPorts: []string{"80", "443", "8080", "8443"},
+ ApplyIPFilterForDomain: false,
+}
+
+func init() {
+ // 注册到全局配置管理器
+ config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting)
+}
+
+func GetFetchSetting() *FetchSetting {
+ return &defaultFetchSetting
+}
diff --git a/types/error.go b/types/error.go
index 883ee064..a42e8438 100644
--- a/types/error.go
+++ b/types/error.go
@@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string {
return string(e.errorCode)
}
errStr := e.Err.Error()
+ if e.errorCode == ErrorCodeCountTokenFailed {
+ return errStr
+ }
return common.MaskSensitiveInfo(errStr)
}
@@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
Code: e.errorCode,
}
}
-
- result.Message = common.MaskSensitiveInfo(result.Message)
+ if e.errorCode != ErrorCodeCountTokenFailed {
+ result.Message = common.MaskSensitiveInfo(result.Message)
+ }
return result
}
@@ -178,7 +182,9 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
Type: string(e.errorType),
}
}
- result.Message = common.MaskSensitiveInfo(result.Message)
+ if e.errorCode != ErrorCodeCountTokenFailed {
+ result.Message = common.MaskSensitiveInfo(result.Message)
+ }
return result
}
diff --git a/types/price_data.go b/types/price_data.go
index f6a92d7e..ec7fcdfe 100644
--- a/types/price_data.go
+++ b/types/price_data.go
@@ -15,6 +15,8 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
+ AudioRatio float64
+ AudioCompletionRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
@@ -27,5 +29,5 @@ type PerCallPriceData struct {
}
func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
}
diff --git a/web/jsconfig.json b/web/jsconfig.json
new file mode 100644
index 00000000..ced4d054
--- /dev/null
+++ b/web/jsconfig.json
@@ -0,0 +1,9 @@
+{
+ "compilerOptions": {
+ "baseUrl": "./",
+ "paths": {
+ "@/*": ["src/*"]
+ }
+ },
+ "include": ["src/**/*"]
+}
\ No newline at end of file
diff --git a/web/src/components/layout/headerbar/UserArea.jsx b/web/src/components/layout/headerbar/UserArea.jsx
index 8ea70f47..9fc011da 100644
--- a/web/src/components/layout/headerbar/UserArea.jsx
+++ b/web/src/components/layout/headerbar/UserArea.jsx
@@ -17,7 +17,7 @@ along with this program. If not, see
+ {i18next.t('图片生成调用:${{price}} / 1次', { + price: imageGenerationCallPrice, + })} +
+ )}
{(() => {
// 构建输入部分描述
@@ -1211,6 +1220,16 @@ export function renderModelPrice(
},
)
: '',
+ imageGenerationCall && imageGenerationCallPrice > 0
+ ? i18next.t(
+ ' + 图片生成调用 ${{price}} / 1次 * {{ratioType}} {{ratio}}',
+ {
+ price: imageGenerationCallPrice,
+ ratio: groupRatio,
+ ratioType: ratioLabel,
+ },
+ )
+ : '',
].join('');
return i18next.t(
diff --git a/web/src/hooks/common/useSidebar.js b/web/src/hooks/common/useSidebar.js
index 5dce44f9..13d76fd8 100644
--- a/web/src/hooks/common/useSidebar.js
+++ b/web/src/hooks/common/useSidebar.js
@@ -21,6 +21,10 @@ import { useState, useEffect, useMemo, useContext } from 'react';
import { StatusContext } from '../../context/Status';
import { API } from '../../helpers';
+// 创建一个全局事件系统来同步所有useSidebar实例
+const sidebarEventTarget = new EventTarget();
+const SIDEBAR_REFRESH_EVENT = 'sidebar-refresh';
+
export const useSidebar = () => {
const [statusState] = useContext(StatusContext);
const [userConfig, setUserConfig] = useState(null);
@@ -124,9 +128,12 @@ export const useSidebar = () => {
// 刷新用户配置的方法(供外部调用)
const refreshUserConfig = async () => {
- if (Object.keys(adminConfig).length > 0) {
+ if (Object.keys(adminConfig).length > 0) {
await loadUserConfig();
}
+
+ // 触发全局刷新事件,通知所有useSidebar实例更新
+ sidebarEventTarget.dispatchEvent(new CustomEvent(SIDEBAR_REFRESH_EVENT));
};
// 加载用户配置
@@ -137,6 +144,21 @@ export const useSidebar = () => {
}
}, [adminConfig]);
+ // 监听全局刷新事件
+ useEffect(() => {
+ const handleRefresh = () => {
+ if (Object.keys(adminConfig).length > 0) {
+ loadUserConfig();
+ }
+ };
+
+ sidebarEventTarget.addEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh);
+
+ return () => {
+ sidebarEventTarget.removeEventListener(SIDEBAR_REFRESH_EVENT, handleRefresh);
+ };
+ }, [adminConfig]);
+
// 计算最终的显示配置
const finalConfig = useMemo(() => {
const result = {};
diff --git a/web/src/hooks/dashboard/useDashboardStats.jsx b/web/src/hooks/dashboard/useDashboardStats.jsx
index aa9677a5..dbf3b67e 100644
--- a/web/src/hooks/dashboard/useDashboardStats.jsx
+++ b/web/src/hooks/dashboard/useDashboardStats.jsx
@@ -102,7 +102,7 @@ export const useDashboardStats = (
},
{
title: t('统计Tokens'),
- value: isNaN(consumeTokens) ? 0 : consumeTokens,
+ value: isNaN(consumeTokens) ? 0 : consumeTokens.toLocaleString(),
icon: