diff --git a/README.fr.md b/README.fr.md
new file mode 100644
index 00000000..de788ede
--- /dev/null
+++ b/README.fr.md
@@ -0,0 +1,216 @@
+
+
+## 📝 Description du projet
+
+> [!NOTE]
+> Il s'agit d'un projet open-source développé sur la base de [One API](https://github.com/songquanpeng/one-api)
+
+> [!IMPORTANT]
+> - Ce projet est uniquement destiné à des fins d'apprentissage personnel, sans garantie de stabilité ni de support technique.
+> - Les utilisateurs doivent se conformer aux [Conditions d'utilisation](https://openai.com/policies/terms-of-use) d'OpenAI et aux **lois et réglementations applicables**, et ne doivent pas l'utiliser à des fins illégales.
+> - Conformément aux [《Mesures provisoires pour la gestion des services d'intelligence artificielle générative》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), veuillez ne fournir aucun service d'IA générative non enregistré au public en Chine.
+
+
+
+## 📚 Documentation
+
+Pour une documentation détaillée, veuillez consulter notre Wiki officiel : [https://docs.newapi.pro/](https://docs.newapi.pro/)
+
+Vous pouvez également accéder au DeepWiki généré par l'IA :
+[](https://deepwiki.com/QuantumNous/new-api)
+
+## ✨ Fonctionnalités clés
+
+New API offre un large éventail de fonctionnalités, veuillez vous référer à [Présentation des fonctionnalités](https://docs.newapi.pro/wiki/features-introduction) pour plus de détails :
+
+1. 🎨 Nouvelle interface utilisateur
+2. 🌍 Prise en charge multilingue
+3. 💰 Fonctionnalité de recharge en ligne (YiPay)
+4. 🔍 Prise en charge de la recherche de quotas d'utilisation avec des clés (fonctionne avec [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
+5. 🔄 Compatible avec la base de données originale de One API
+6. 💵 Prise en charge de la tarification des modèles de paiement à l'utilisation
+7. ⚖️ Prise en charge de la sélection aléatoire pondérée des canaux
+8. 📈 Tableau de bord des données (console)
+9. 🔒 Regroupement de jetons et restrictions de modèles
+10. 🤖 Prise en charge de plus de méthodes de connexion par autorisation (LinuxDO, Telegram, OIDC)
+11. 🔄 Prise en charge des modèles Rerank (Cohere et Jina), [Documentation de l'API](https://docs.newapi.pro/api/jinaai-rerank)
+12. ⚡ Prise en charge de l'API OpenAI Realtime (y compris les canaux Azure), [Documentation de l'API](https://docs.newapi.pro/api/openai-realtime)
+13. ⚡ Prise en charge du format Claude Messages, [Documentation de l'API](https://docs.newapi.pro/api/anthropic-chat)
+14. Prise en charge de l'accès à l'interface de discussion via la route /chat2link
+15. 🧠 Prise en charge de la définition de l'effort de raisonnement via les suffixes de nom de modèle :
+ 1. Modèles de la série o d'OpenAI
+ - Ajouter le suffixe `-high` pour un effort de raisonnement élevé (par exemple : `o3-mini-high`)
+ - Ajouter le suffixe `-medium` pour un effort de raisonnement moyen (par exemple : `o3-mini-medium`)
+ - Ajouter le suffixe `-low` pour un effort de raisonnement faible (par exemple : `o3-mini-low`)
+ 2. Modèles de pensée de Claude
+ - Ajouter le suffixe `-thinking` pour activer le mode de pensée (par exemple : `claude-3-7-sonnet-20250219-thinking`)
+16. 🔄 Fonctionnalité de la pensée au contenu
+17. 🔄 Limitation du débit du modèle pour les utilisateurs
+18. 💰 Prise en charge de la facturation du cache, qui permet de facturer à un ratio défini lorsque le cache est atteint :
+ 1. Définir l'option `Ratio de cache d'invite` dans `Paramètres système->Paramètres de fonctionnement`
+ 2. Définir le `Ratio de cache d'invite` dans le canal, plage de 0 à 1, par exemple, le définir sur 0,5 signifie facturer à 50 % lorsque le cache est atteint
+ 3. Canaux pris en charge :
+ - [x] OpenAI
+ - [x] Azure
+ - [x] DeepSeek
+ - [x] Claude
+
+## Prise en charge des modèles
+
+Cette version prend en charge plusieurs modèles, veuillez vous référer à [Documentation de l'API-Interface de relais](https://docs.newapi.pro/api) pour plus de détails :
+
+1. Modèles tiers **gpts** (gpt-4-gizmo-*)
+2. Canal tiers [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy), [Documentation de l'API](https://docs.newapi.pro/api/midjourney-proxy-image)
+3. Canal tiers [Suno API](https://github.com/Suno-API/Suno-API), [Documentation de l'API](https://docs.newapi.pro/api/suno-music)
+4. Canaux personnalisés, prenant en charge la saisie complète de l'adresse d'appel
+5. Modèles Rerank ([Cohere](https://cohere.ai/) et [Jina](https://jina.ai/)), [Documentation de l'API](https://docs.newapi.pro/api/jinaai-rerank)
+6. Format de messages Claude, [Documentation de l'API](https://docs.newapi.pro/api/anthropic-chat)
+7. Dify, ne prend actuellement en charge que chatflow
+
+## Configuration des variables d'environnement
+
+Pour des instructions de configuration détaillées, veuillez vous référer à [Guide d'installation-Configuration des variables d'environnement](https://docs.newapi.pro/installation/environment-variables) :
+
+- `GENERATE_DEFAULT_TOKEN` : S'il faut générer des jetons initiaux pour les utilisateurs nouvellement enregistrés, la valeur par défaut est `false`
+- `STREAMING_TIMEOUT` : Délai d'expiration de la réponse en streaming, la valeur par défaut est de 300 secondes
+- `DIFY_DEBUG` : S'il faut afficher les informations sur le flux de travail et les nœuds pour les canaux Dify, la valeur par défaut est `true`
+- `FORCE_STREAM_OPTION` : S'il faut remplacer le paramètre client stream_options, la valeur par défaut est `true`
+- `GET_MEDIA_TOKEN` : S'il faut compter les jetons d'image, la valeur par défaut est `true`
+- `GET_MEDIA_TOKEN_NOT_STREAM` : S'il faut compter les jetons d'image dans les cas sans streaming, la valeur par défaut est `true`
+- `UPDATE_TASK` : S'il faut mettre à jour les tâches asynchrones (Midjourney, Suno), la valeur par défaut est `true`
+- `COHERE_SAFETY_SETTING` : Paramètres de sécurité du modèle Cohere, les options sont `NONE`, `CONTEXTUAL`, `STRICT`, la valeur par défaut est `NONE`
+- `GEMINI_VISION_MAX_IMAGE_NUM` : Nombre maximum d'images pour les modèles Gemini, la valeur par défaut est `16`
+- `MAX_FILE_DOWNLOAD_MB` : Taille maximale de téléchargement de fichier en Mo, la valeur par défaut est `20`
+- `CRYPTO_SECRET` : Clé de chiffrement utilisée pour chiffrer le contenu de la base de données
+- `AZURE_DEFAULT_API_VERSION` : Version de l'API par défaut du canal Azure, la valeur par défaut est `2025-04-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE` : Durée de la limite de notification, la valeur par défaut est de `10` minutes
+- `NOTIFY_LIMIT_COUNT` : Nombre maximal de notifications utilisateur dans la durée spécifiée, la valeur par défaut est `2`
+- `ERROR_LOG_ENABLED=true` : S'il faut enregistrer et afficher les journaux d'erreurs, la valeur par défaut est `false`
+
+## Déploiement
+
+Pour des guides de déploiement détaillés, veuillez vous référer à [Guide d'installation-Méthodes de déploiement](https://docs.newapi.pro/installation) :
+
+> [!TIP]
+> Dernière image Docker : `calciumion/new-api:latest`
+
+### Considérations sur le déploiement multi-machines
+- La variable d'environnement `SESSION_SECRET` doit être définie, sinon l'état de connexion sera incohérent sur plusieurs machines
+- Si vous partagez Redis, `CRYPTO_SECRET` doit être défini, sinon le contenu de Redis ne pourra pas être consulté sur plusieurs machines
+
+### Exigences de déploiement
+- Base de données locale (par défaut) : SQLite (le déploiement Docker doit monter le répertoire `/data`)
+- Base de données distante : MySQL version >= 5.7.8, PgSQL version >= 9.6
+
+### Méthodes de déploiement
+
+#### Utilisation de la fonctionnalité Docker du panneau BaoTa
+Installez le panneau BaoTa (version **9.2.0** ou supérieure), recherchez **New-API** dans le magasin d'applications et installez-le.
+[Tutoriel avec des images](./docs/BT.md)
+
+#### Utilisation de Docker Compose (recommandé)
+```shell
+# Télécharger le projet
+git clone https://github.com/Calcium-Ion/new-api.git
+cd new-api
+# Modifier docker-compose.yml si nécessaire
+# Démarrer
+docker-compose up -d
+```
+
+#### Utilisation directe de l'image Docker
+```shell
+# Utilisation de SQLite
+docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
+
+# Utilisation de MySQL
+docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
+```
+
+## Nouvelle tentative de canal et cache
+La fonctionnalité de nouvelle tentative de canal a été implémentée, vous pouvez définir le nombre de tentatives dans `Paramètres->Paramètres de fonctionnement->Paramètres généraux`. Il est **recommandé d'activer la mise en cache**.
+
+### Méthode de configuration du cache
+1. `REDIS_CONN_STRING` : Définir Redis comme cache
+2. `MEMORY_CACHE_ENABLED` : Activer le cache mémoire (pas besoin de le définir manuellement si Redis est défini)
+
+## Documentation de l'API
+
+Pour une documentation détaillée de l'API, veuillez vous référer à [Documentation de l'API](https://docs.newapi.pro/api) :
+
+- [API de discussion](https://docs.newapi.pro/api/openai-chat)
+- [API d'image](https://docs.newapi.pro/api/openai-image)
+- [API de rerank](https://docs.newapi.pro/api/jinaai-rerank)
+- [API en temps réel](https://docs.newapi.pro/api/openai-realtime)
+- [API de discussion Claude (messages)](https://docs.newapi.pro/api/anthropic-chat)
+
+## Projets connexes
+- [One API](https://github.com/songquanpeng/one-api) : Projet original
+- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy) : Prise en charge de l'interface Midjourney
+- [chatnio](https://github.com/Deeptrain-Community/chatnio) : Solution B/C unique d'IA de nouvelle génération
+- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) : Interroger le quota d'utilisation avec une clé
+
+Autres projets basés sur New API :
+- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) : Version optimisée hautes performances de New API
+- [VoAPI](https://github.com/VoAPI/VoAPI) : Version embellie du frontend basée sur New API
+
+## Aide et support
+
+Si vous avez des questions, veuillez vous référer à [Aide et support](https://docs.newapi.pro/support) :
+- [Interaction avec la communauté](https://docs.newapi.pro/support/community-interaction)
+- [Commentaires sur les problèmes](https://docs.newapi.pro/support/feedback-issues)
+- [FAQ](https://docs.newapi.pro/support/faq)
+
+## 🌟 Historique des étoiles
+
+[](https://star-history.com/#Calcium-Ion/new-api&Date)
\ No newline at end of file
diff --git a/README.md b/README.md
index d68b3e13..2103fe8f 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
diff --git a/common/api_type.go b/common/api_type.go
index 5ac46c86..855eef84 100644
--- a/common/api_type.go
+++ b/common/api_type.go
@@ -67,6 +67,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeJimeng
case constant.ChannelTypeMoonshot:
apiType = constant.APITypeMoonshot
+ case constant.ChannelTypeSubmodel:
+ apiType = constant.APITypeSubmodel
}
if apiType == -1 {
return constant.APITypeOpenAI, false
diff --git a/common/database.go b/common/database.go
index 71dbd94d..38a54d5e 100644
--- a/common/database.go
+++ b/common/database.go
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false
var UsingClickHouse = false
-var SQLitePath = "one-api.db?_busy_timeout=30000"
+var SQLitePath = "one-api.db?_busy_timeout=30000"
\ No newline at end of file
diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go
index ffc26350..25f9c68e 100644
--- a/common/endpoint_defaults.go
+++ b/common/endpoint_defaults.go
@@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
+ constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"},
}
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
diff --git a/common/ip.go b/common/ip.go
new file mode 100644
index 00000000..bfb64ee7
--- /dev/null
+++ b/common/ip.go
@@ -0,0 +1,22 @@
+package common
+
+import "net"
+
+func IsPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+
+ private := []net.IPNet{
+ {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
+ {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
+ {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
+ }
+
+ for _, privateNet := range private {
+ if privateNet.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go
new file mode 100644
index 00000000..6f7d289f
--- /dev/null
+++ b/common/ssrf_protection.go
@@ -0,0 +1,327 @@
+package common
+
+import (
+ "fmt"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+// SSRFProtection SSRF防护配置
+type SSRFProtection struct {
+ AllowPrivateIp bool
+ DomainFilterMode bool // true: 白名单, false: 黑名单
+ DomainList []string // domain format, e.g. example.com, *.example.com
+ IpFilterMode bool // true: 白名单, false: 黑名单
+ IpList []string // CIDR or single IP
+ AllowedPorts []int // 允许的端口范围
+ ApplyIPFilterForDomain bool // 对域名启用IP过滤
+}
+
+// DefaultSSRFProtection 默认SSRF防护配置
+var DefaultSSRFProtection = &SSRFProtection{
+ AllowPrivateIp: false,
+ DomainFilterMode: true,
+ DomainList: []string{},
+ IpFilterMode: true,
+ IpList: []string{},
+ AllowedPorts: []int{},
+}
+
+// isPrivateIP 检查IP是否为私有地址
+func isPrivateIP(ip net.IP) bool {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+
+ // 检查私有网段
+ private := []net.IPNet{
+ {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
+ {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
+ {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
+ {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
+ {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
+ {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
+ {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
+ }
+
+ for _, privateNet := range private {
+ if privateNet.Contains(ip) {
+ return true
+ }
+ }
+
+ // 检查IPv6私有地址
+ if ip.To4() == nil {
+ // IPv6 loopback
+ if ip.Equal(net.IPv6loopback) {
+ return true
+ }
+ // IPv6 link-local
+ if strings.HasPrefix(ip.String(), "fe80:") {
+ return true
+ }
+ // IPv6 unique local
+ if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
+ return true
+ }
+ }
+
+ return false
+}
+
+// parsePortRanges 解析端口范围配置
+// 支持格式: "80", "443", "8000-9000"
+func parsePortRanges(portConfigs []string) ([]int, error) {
+ var ports []int
+
+ for _, config := range portConfigs {
+ config = strings.TrimSpace(config)
+ if config == "" {
+ continue
+ }
+
+ if strings.Contains(config, "-") {
+ // 处理端口范围 "8000-9000"
+ parts := strings.Split(config, "-")
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid port range format: %s", config)
+ }
+
+ startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+ if err != nil {
+ return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
+ }
+
+ endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
+ if err != nil {
+ return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
+ }
+
+ if startPort > endPort {
+ return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
+ }
+
+ if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
+ return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
+ }
+
+ // 添加范围内的所有端口
+ for port := startPort; port <= endPort; port++ {
+ ports = append(ports, port)
+ }
+ } else {
+ // 处理单个端口 "80"
+ port, err := strconv.Atoi(config)
+ if err != nil {
+ return nil, fmt.Errorf("invalid port number: %s", config)
+ }
+
+ if port < 1 || port > 65535 {
+ return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
+ }
+
+ ports = append(ports, port)
+ }
+ }
+
+ return ports, nil
+}
+
+// isAllowedPort 检查端口是否被允许
+func (p *SSRFProtection) isAllowedPort(port int) bool {
+ if len(p.AllowedPorts) == 0 {
+ return true // 如果没有配置端口限制,则允许所有端口
+ }
+
+ for _, allowedPort := range p.AllowedPorts {
+ if port == allowedPort {
+ return true
+ }
+ }
+ return false
+}
+
+// isDomainWhitelisted 检查域名是否在白名单中
+func isDomainListed(domain string, list []string) bool {
+ if len(list) == 0 {
+ return false
+ }
+
+ domain = strings.ToLower(domain)
+ for _, item := range list {
+ item = strings.ToLower(strings.TrimSpace(item))
+ if item == "" {
+ continue
+ }
+ // 精确匹配
+ if domain == item {
+ return true
+ }
+ // 通配符匹配 (*.example.com)
+ if strings.HasPrefix(item, "*.") {
+ suffix := strings.TrimPrefix(item, "*.")
+ if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+func (p *SSRFProtection) isDomainAllowed(domain string) bool {
+ listed := isDomainListed(domain, p.DomainList)
+ if p.DomainFilterMode { // 白名单
+ return listed
+ }
+ // 黑名单
+ return !listed
+}
+
+// isIPWhitelisted 检查IP是否在白名单中
+
+func isIPListed(ip net.IP, list []string) bool {
+ if len(list) == 0 {
+ return false
+ }
+
+ for _, whitelistCIDR := range list {
+ _, network, err := net.ParseCIDR(whitelistCIDR)
+ if err != nil {
+ // 尝试作为单个IP处理
+ if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
+ if ip.Equal(whitelistIP) {
+ return true
+ }
+ }
+ continue
+ }
+
+ if network.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// IsIPAccessAllowed 检查IP是否允许访问
+func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
+ // 私有IP限制
+ if isPrivateIP(ip) && !p.AllowPrivateIp {
+ return false
+ }
+
+ listed := isIPListed(ip, p.IpList)
+ if p.IpFilterMode { // 白名单
+ return listed
+ }
+ // 黑名单
+ return !listed
+}
+
+// ValidateURL 验证URL是否安全
+func (p *SSRFProtection) ValidateURL(urlStr string) error {
+ // 解析URL
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return fmt.Errorf("invalid URL format: %v", err)
+ }
+
+ // 只允许HTTP/HTTPS协议
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
+ }
+
+ // 解析主机和端口
+ host, portStr, err := net.SplitHostPort(u.Host)
+ if err != nil {
+ // 没有端口,使用默认端口
+ host = u.Hostname()
+ if u.Scheme == "https" {
+ portStr = "443"
+ } else {
+ portStr = "80"
+ }
+ }
+
+ // 验证端口
+ port, err := strconv.Atoi(portStr)
+ if err != nil {
+ return fmt.Errorf("invalid port: %s", portStr)
+ }
+
+ if !p.isAllowedPort(port) {
+ return fmt.Errorf("port %d is not allowed", port)
+ }
+
+ // 如果 host 是 IP,则跳过域名检查
+ if ip := net.ParseIP(host); ip != nil {
+ if !p.IsIPAccessAllowed(ip) {
+ if isPrivateIP(ip) {
+ return fmt.Errorf("private IP address not allowed: %s", ip.String())
+ }
+ if p.IpFilterMode {
+ return fmt.Errorf("ip not in whitelist: %s", ip.String())
+ }
+ return fmt.Errorf("ip in blacklist: %s", ip.String())
+ }
+ return nil
+ }
+
+ // 先进行域名过滤
+ if !p.isDomainAllowed(host) {
+ if p.DomainFilterMode {
+ return fmt.Errorf("domain not in whitelist: %s", host)
+ }
+ return fmt.Errorf("domain in blacklist: %s", host)
+ }
+
+ // 若未启用对域名应用IP过滤,则到此通过
+ if !p.ApplyIPFilterForDomain {
+ return nil
+ }
+
+ // 解析域名对应IP并检查
+ ips, err := net.LookupIP(host)
+ if err != nil {
+ return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
+ }
+ for _, ip := range ips {
+ if !p.IsIPAccessAllowed(ip) {
+ if isPrivateIP(ip) && !p.AllowPrivateIp {
+ return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
+ }
+ if p.IpFilterMode {
+ return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
+ }
+ return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
+ }
+ }
+ return nil
+}
+
+// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
+func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
+ // 如果SSRF防护被禁用,直接返回成功
+ if !enableSSRFProtection {
+ return nil
+ }
+
+ // 解析端口范围配置
+ allowedPortInts, err := parsePortRanges(allowedPorts)
+ if err != nil {
+ return fmt.Errorf("request reject - invalid port configuration: %v", err)
+ }
+
+ protection := &SSRFProtection{
+ AllowPrivateIp: allowPrivateIp,
+ DomainFilterMode: domainFilterMode,
+ DomainList: domainList,
+ IpFilterMode: ipFilterMode,
+ IpList: ipList,
+ AllowedPorts: allowedPortInts,
+ ApplyIPFilterForDomain: applyIPFilterForDomain,
+ }
+ return protection.ValidateURL(urlStr)
+}
diff --git a/common/sys_log.go b/common/sys_log.go
index 478015f0..b29adc3e 100644
--- a/common/sys_log.go
+++ b/common/sys_log.go
@@ -2,9 +2,10 @@ package common
import (
"fmt"
- "github.com/gin-gonic/gin"
"os"
"time"
+
+ "github.com/gin-gonic/gin"
)
func SysLog(s string) {
@@ -22,3 +23,33 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
+
+func LogStartupSuccess(startTime time.Time, port string) {
+
+ duration := time.Since(startTime)
+ durationMs := duration.Milliseconds()
+
+ // Get network IPs
+ networkIps := GetNetworkIps()
+
+ // Print blank line for spacing
+ fmt.Fprintf(gin.DefaultWriter, "\n")
+
+ // Print the main success message
+ fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
+ fmt.Fprintf(gin.DefaultWriter, "\n")
+
+ // Skip fancy startup message in container environments
+ if !IsRunningInContainer() {
+ // Print local URL
+ fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
+ }
+
+ // Print network URLs
+ for _, ip := range networkIps {
+ fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
+ }
+
+ // Print blank line for spacing
+ fmt.Fprintf(gin.DefaultWriter, "\n")
+}
diff --git a/common/utils.go b/common/utils.go
index 883abfd1..21f72ec6 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -68,6 +68,78 @@ func GetIp() (ip string) {
return
}
+func GetNetworkIps() []string {
+ var networkIps []string
+ ips, err := net.InterfaceAddrs()
+ if err != nil {
+ log.Println(err)
+ return networkIps
+ }
+
+ for _, a := range ips {
+ if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+ if ipNet.IP.To4() != nil {
+ ip := ipNet.IP.String()
+ // Include common private network ranges
+ if strings.HasPrefix(ip, "10.") ||
+ strings.HasPrefix(ip, "172.") ||
+ strings.HasPrefix(ip, "192.168.") {
+ networkIps = append(networkIps, ip)
+ }
+ }
+ }
+ }
+ return networkIps
+}
+
+// IsRunningInContainer detects if the application is running inside a container
+func IsRunningInContainer() bool {
+ // Method 1: Check for .dockerenv file (Docker containers)
+ if _, err := os.Stat("/.dockerenv"); err == nil {
+ return true
+ }
+
+ // Method 2: Check cgroup for container indicators
+ if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
+ content := string(data)
+ if strings.Contains(content, "docker") ||
+ strings.Contains(content, "containerd") ||
+ strings.Contains(content, "kubepods") ||
+ strings.Contains(content, "/lxc/") {
+ return true
+ }
+ }
+
+ // Method 3: Check environment variables commonly set by container runtimes
+ containerEnvVars := []string{
+ "KUBERNETES_SERVICE_HOST",
+ "DOCKER_CONTAINER",
+ "container",
+ }
+
+ for _, envVar := range containerEnvVars {
+ if os.Getenv(envVar) != "" {
+ return true
+ }
+ }
+
+ // Method 4: Check if init process is not the traditional init
+ if data, err := os.ReadFile("/proc/1/comm"); err == nil {
+ comm := strings.TrimSpace(string(data))
+ // In containers, process 1 is often not "init" or "systemd"
+ if comm != "init" && comm != "systemd" {
+ // Additional check: if it's a common container entrypoint
+ if strings.Contains(comm, "docker") ||
+ strings.Contains(comm, "containerd") ||
+ strings.Contains(comm, "runc") {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
diff --git a/constant/api_type.go b/constant/api_type.go
index f62d91d5..0ea5048f 100644
--- a/constant/api_type.go
+++ b/constant/api_type.go
@@ -31,6 +31,7 @@ const (
APITypeXai
APITypeCoze
APITypeJimeng
- APITypeMoonshot // this one is only for count, do not add any channel after this
- APITypeDummy // this one is only for count, do not add any channel after this
+ APITypeMoonshot
+ APITypeSubmodel
+ APITypeDummy // this one is only for count, do not add any channel after this
)
diff --git a/constant/channel.go b/constant/channel.go
index 2e1cc5b0..34fb20f4 100644
--- a/constant/channel.go
+++ b/constant/channel.go
@@ -50,8 +50,10 @@ const (
ChannelTypeKling = 50
ChannelTypeJimeng = 51
ChannelTypeVidu = 52
+ ChannelTypeSubmodel = 53
ChannelTypeDummy // this one is only for count, do not add any channel after this
+
)
var ChannelBaseURLs = []string{
@@ -108,4 +110,5 @@ var ChannelBaseURLs = []string{
"https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51
"https://api.vidu.cn", //52
+ "https://llm.submodel.ai", //53
}
diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go
index ef096b75..f799e5ba 100644
--- a/constant/endpoint_type.go
+++ b/constant/endpoint_type.go
@@ -9,6 +9,7 @@ const (
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
+ EndpointTypeEmbeddings EndpointType = "embeddings"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"
diff --git a/constant/task.go b/constant/task.go
index 21790145..e174fd60 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -11,8 +11,10 @@ const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
- TaskActionGenerate = "generate"
- TaskActionTextGenerate = "textGenerate"
+ TaskActionGenerate = "generate"
+ TaskActionTextGenerate = "textGenerate"
+ TaskActionFirstTailGenerate = "firstTailGenerate"
+ TaskActionReferenceGenerate = "referenceGenerate"
)
var SunoModel2Action = map[string]string{
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 18acf231..1082b9e7 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -10,7 +10,7 @@ import (
"one-api/constant"
"one-api/model"
"one-api/service"
- "one-api/setting"
+ "one-api/setting/operation_setting"
"one-api/types"
"strconv"
"time"
@@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
}
availableBalanceCny := response.Data.AvailableBalance
- availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+ availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
channel.UpdateBalance(availableBalanceUsd)
return availableBalanceUsd, nil
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 5fc6d749..b3a3be4e 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -38,7 +38,7 @@ type testResult struct {
newAPIError *types.NewAPIError
}
-func testChannel(channel *model.Channel, testModel string) testResult {
+func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
tik := time.Now()
if channel.Type == constant.ChannelTypeMidjourney {
return testResult{
@@ -81,13 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult {
requestPath := "/v1/chat/completions"
- // 先判断是否为 Embedding 模型
- if strings.Contains(strings.ToLower(testModel), "embedding") ||
- strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
- strings.Contains(testModel, "bge-") || // bge 系列模型
- strings.Contains(testModel, "embed") ||
- channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
- requestPath = "/v1/embeddings" // 修改请求路径
+ // 如果指定了端点类型,使用指定的端点类型
+ if endpointType != "" {
+ if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
+ requestPath = endpointInfo.Path
+ }
+ } else {
+ // 如果没有指定端点类型,使用原有的自动检测逻辑
+ // 先判断是否为 Embedding 模型
+ if strings.Contains(strings.ToLower(testModel), "embedding") ||
+ strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+ strings.Contains(testModel, "bge-") || // bge 系列模型
+ strings.Contains(testModel, "embed") ||
+ channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
+ requestPath = "/v1/embeddings" // 修改请求路径
+ }
+
+ // VolcEngine 图像生成模型
+ if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+ requestPath = "/v1/images/generations"
+ }
}
c.Request = &http.Request{
@@ -133,14 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult {
newAPIError: newAPIError,
}
}
- request := buildTestRequest(testModel)
- // Determine relay format based on request path
- relayFormat := types.RelayFormatOpenAI
- if c.Request.URL.Path == "/v1/embeddings" {
- relayFormat = types.RelayFormatEmbedding
+ // Determine relay format based on endpoint type or request path
+ var relayFormat types.RelayFormat
+ if endpointType != "" {
+ // 根据指定的端点类型设置 relayFormat
+ switch constant.EndpointType(endpointType) {
+ case constant.EndpointTypeOpenAI:
+ relayFormat = types.RelayFormatOpenAI
+ case constant.EndpointTypeOpenAIResponse:
+ relayFormat = types.RelayFormatOpenAIResponses
+ case constant.EndpointTypeAnthropic:
+ relayFormat = types.RelayFormatClaude
+ case constant.EndpointTypeGemini:
+ relayFormat = types.RelayFormatGemini
+ case constant.EndpointTypeJinaRerank:
+ relayFormat = types.RelayFormatRerank
+ case constant.EndpointTypeImageGeneration:
+ relayFormat = types.RelayFormatOpenAIImage
+ case constant.EndpointTypeEmbeddings:
+ relayFormat = types.RelayFormatEmbedding
+ default:
+ relayFormat = types.RelayFormatOpenAI
+ }
+ } else {
+ // 根据请求路径自动检测
+ relayFormat = types.RelayFormatOpenAI
+ if c.Request.URL.Path == "/v1/embeddings" {
+ relayFormat = types.RelayFormatEmbedding
+ }
+ if c.Request.URL.Path == "/v1/images/generations" {
+ relayFormat = types.RelayFormatOpenAIImage
+ }
+ if c.Request.URL.Path == "/v1/messages" {
+ relayFormat = types.RelayFormatClaude
+ }
+ if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
+ relayFormat = types.RelayFormatGemini
+ }
+ if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
+ relayFormat = types.RelayFormatRerank
+ }
+ if c.Request.URL.Path == "/v1/responses" {
+ relayFormat = types.RelayFormatOpenAIResponses
+ }
}
+ request := buildTestRequest(testModel, endpointType)
+
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
if err != nil {
@@ -163,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
testModel = info.UpstreamModelName
- request.Model = testModel
+ // 更新请求中的模型名称
+ request.SetModelName(testModel)
apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
@@ -193,17 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult {
var convertedRequest any
// 根据 RelayMode 选择正确的转换函数
- if info.RelayMode == relayconstant.RelayModeEmbeddings {
- // 创建一个 EmbeddingRequest
- embeddingRequest := dto.EmbeddingRequest{
- Input: request.Input,
- Model: request.Model,
+ switch info.RelayMode {
+ case relayconstant.RelayModeEmbeddings:
+ // Embedding 请求 - request 已经是正确的类型
+ if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
+ convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
+ } else {
+ return testResult{
+ context: c,
+ localErr: errors.New("invalid embedding request type"),
+ newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
+ }
+ }
+ case relayconstant.RelayModeImagesGenerations:
+ // 图像生成请求 - request 已经是正确的类型
+ if imageReq, ok := request.(*dto.ImageRequest); ok {
+ convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
+ } else {
+ return testResult{
+ context: c,
+ localErr: errors.New("invalid image request type"),
+ newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
+ }
+ }
+ case relayconstant.RelayModeRerank:
+ // Rerank 请求 - request 已经是正确的类型
+ if rerankReq, ok := request.(*dto.RerankRequest); ok {
+ convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
+ } else {
+ return testResult{
+ context: c,
+ localErr: errors.New("invalid rerank request type"),
+ newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
+ }
+ }
+ case relayconstant.RelayModeResponses:
+ // Response 请求 - request 已经是正确的类型
+ if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
+ convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
+ } else {
+ return testResult{
+ context: c,
+ localErr: errors.New("invalid response request type"),
+ newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
+ }
+ }
+ default:
+ // Chat/Completion 等其他请求类型
+ if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
+ convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
+ } else {
+ return testResult{
+ context: c,
+ localErr: errors.New("invalid general request type"),
+ newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
+ }
}
- // 调用专门用于 Embedding 的转换函数
- convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
- } else {
- // 对其他所有请求类型(如 Chat),保持原有逻辑
- convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
}
if err != nil {
@@ -235,7 +334,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- err := service.RelayErrorHandler(httpResp, true)
+ err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
return testResult{
context: c,
localErr: err,
@@ -306,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
}
-func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
- testRequest := &dto.GeneralOpenAIRequest{
- Model: "", // this will be set later
- Stream: false,
+func buildTestRequest(model string, endpointType string) dto.Request {
+ // 根据端点类型构建不同的测试请求
+ if endpointType != "" {
+ switch constant.EndpointType(endpointType) {
+ case constant.EndpointTypeEmbeddings:
+ // 返回 EmbeddingRequest
+ return &dto.EmbeddingRequest{
+ Model: model,
+ Input: []any{"hello world"},
+ }
+ case constant.EndpointTypeImageGeneration:
+ // 返回 ImageRequest
+ return &dto.ImageRequest{
+ Model: model,
+ Prompt: "a cute cat",
+ N: 1,
+ Size: "1024x1024",
+ }
+ case constant.EndpointTypeJinaRerank:
+ // 返回 RerankRequest
+ return &dto.RerankRequest{
+ Model: model,
+ Query: "What is Deep Learning?",
+ Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
+ TopN: 2,
+ }
+ case constant.EndpointTypeOpenAIResponse:
+ // 返回 OpenAIResponsesRequest
+ return &dto.OpenAIResponsesRequest{
+ Model: model,
+ Input: json.RawMessage("\"hi\""),
+ }
+ case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
+ // 返回 GeneralOpenAIRequest
+ maxTokens := uint(10)
+ if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
+ maxTokens = 3000
+ }
+ return &dto.GeneralOpenAIRequest{
+ Model: model,
+ Stream: false,
+ Messages: []dto.Message{
+ {
+ Role: "user",
+ Content: "hi",
+ },
+ },
+ MaxTokens: maxTokens,
+ }
+ }
}
+ // 自动检测逻辑(保持原有行为)
// 先判断是否为 Embedding 模型
- if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
- strings.HasPrefix(model, "m3e") || // m3e 系列模型
+ if strings.Contains(strings.ToLower(model), "embedding") ||
+ strings.HasPrefix(model, "m3e") ||
strings.Contains(model, "bge-") {
- testRequest.Model = model
- // Embedding 请求
- testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
- return testRequest
+ // 返回 EmbeddingRequest
+ return &dto.EmbeddingRequest{
+ Model: model,
+ Input: []any{"hello world"},
+ }
}
- // 并非Embedding 模型
+
+ // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
+ testRequest := &dto.GeneralOpenAIRequest{
+ Model: model,
+ Stream: false,
+ Messages: []dto.Message{
+ {
+ Role: "user",
+ Content: "hi",
+ },
+ },
+ }
+
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") {
@@ -334,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 10
}
- testMessage := dto.Message{
- Role: "user",
- Content: "hi",
- }
- testRequest.Model = model
- testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
@@ -363,8 +516,9 @@ func TestChannel(c *gin.Context) {
// }
//}()
testModel := c.Query("model")
+ endpointType := c.Query("endpoint_type")
tik := time.Now()
- result := testChannel(channel, testModel)
+ result := testChannel(channel, testModel, endpointType)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -390,7 +544,6 @@ func TestChannel(c *gin.Context) {
"message": "",
"time": consumedTime,
})
- return
}
var testAllChannelsLock sync.Mutex
@@ -424,7 +577,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
- result := testChannel(channel, "")
+ result := testChannel(channel, "", "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -438,7 +591,7 @@ func testAllChannels(notify bool) error {
// 当错误检查通过,才检查响应时间
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
if milliseconds > disableThreshold {
- err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+ err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
shouldBanChannel = true
}
@@ -475,7 +628,6 @@ func TestAllChannels(c *gin.Context) {
"success": true,
"message": "",
})
- return
}
var autoTestChannelsOnce sync.Once
diff --git a/controller/channel.go b/controller/channel.go
index 70be91d4..4aedee3b 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -6,7 +6,9 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/dto"
"one-api/model"
+ "one-api/service"
"strconv"
"strings"
@@ -187,6 +189,8 @@ func FetchUpstreamModels(c *gin.Context) {
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
+ case constant.ChannelTypeZhipu_v4:
+ url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
default:
url = fmt.Sprintf("%s/v1/models", baseURL)
}
@@ -380,18 +384,9 @@ func GetChannel(c *gin.Context) {
return
}
-// GetChannelKey 验证2FA后获取渠道密钥
+// GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
+// 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
func GetChannelKey(c *gin.Context) {
- type GetChannelKeyRequest struct {
- Code string `json:"code" binding:"required"`
- }
-
- var req GetChannelKeyRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- common.ApiError(c, fmt.Errorf("参数错误: %v", err))
- return
- }
-
userId := c.GetInt("id")
channelId, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -399,24 +394,6 @@ func GetChannelKey(c *gin.Context) {
return
}
- // 获取2FA记录并验证
- twoFA, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err))
- return
- }
-
- if twoFA == nil || !twoFA.IsEnabled {
- common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥"))
- return
- }
-
- // 统一的2FA验证逻辑
- if !validateTwoFactorAuth(twoFA, req.Code) {
- common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试"))
- return
- }
-
// 获取渠道信息(包含密钥)
channel, err := model.GetChannelById(channelId, true)
if err != nil {
@@ -432,10 +409,10 @@ func GetChannelKey(c *gin.Context) {
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
- // 统一的成功响应格式
+ // 返回渠道密钥
c.JSON(http.StatusOK, gin.H{
"success": true,
- "message": "验证成功",
+ "message": "获取成功",
"data": map[string]interface{}{
"key": channel.Key,
},
@@ -500,9 +477,10 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
}
type AddChannelRequest struct {
- Mode string `json:"mode"`
- MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
- Channel *model.Channel `json:"channel"`
+ Mode string `json:"mode"`
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+ BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
+ Channel *model.Channel `json:"channel"`
}
func getVertexArrayKeys(keys string) ([]string, error) {
@@ -560,7 +538,7 @@ func AddChannel(c *gin.Context) {
case "multi_to_single":
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
- if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -585,7 +563,7 @@ func AddChannel(c *gin.Context) {
}
keys = []string{addChannelRequest.Channel.Key}
case "batch":
- if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
// multi json
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil {
@@ -615,6 +593,13 @@ func AddChannel(c *gin.Context) {
}
localChannel := addChannelRequest.Channel
localChannel.Key = key
+ if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
+ keyPrefix := localChannel.Key
+ if len(localChannel.Key) > 8 {
+ keyPrefix = localChannel.Key[:8]
+ }
+ localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
+ }
channels = append(channels, *localChannel)
}
err = model.BatchInsertChannels(channels)
@@ -622,6 +607,7 @@ func AddChannel(c *gin.Context) {
common.ApiError(c, err)
return
}
+ service.ResetProxyClientCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -840,7 +826,7 @@ func UpdateChannel(c *gin.Context) {
}
// 处理 Vertex AI 的特殊情况
- if channel.Type == constant.ChannelTypeVertexAi {
+ if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
// 尝试解析新密钥为JSON数组
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
array, err := getVertexArrayKeys(channel.Key)
@@ -883,6 +869,7 @@ func UpdateChannel(c *gin.Context) {
return
}
model.InitChannelCache()
+ service.ResetProxyClientCache()
channel.Key = ""
clearChannelInfo(&channel.Channel)
c.JSON(http.StatusOK, gin.H{
@@ -1092,8 +1079,8 @@ func CopyChannel(c *gin.Context) {
// MultiKeyManageRequest represents the request for multi-key management operations
type MultiKeyManageRequest struct {
ChannelId int `json:"channel_id"`
- Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
- KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
+ Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
+ KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
Page int `json:"page,omitempty"` // for get_key_status pagination
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
@@ -1421,6 +1408,86 @@ func ManageMultiKeys(c *gin.Context) {
})
return
+ case "delete_key":
+ if request.KeyIndex == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "未指定要删除的密钥索引",
+ })
+ return
+ }
+
+ keyIndex := *request.KeyIndex
+ if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "密钥索引超出范围",
+ })
+ return
+ }
+
+ keys := channel.GetKeys()
+ var remainingKeys []string
+ var newStatusList = make(map[int]int)
+ var newDisabledTime = make(map[int]int64)
+ var newDisabledReason = make(map[int]string)
+
+ newIndex := 0
+ for i, key := range keys {
+ // 跳过要删除的密钥
+ if i == keyIndex {
+ continue
+ }
+
+ remainingKeys = append(remainingKeys, key)
+
+ // 保留其他密钥的状态信息,重新索引
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
+ newStatusList[newIndex] = status
+ }
+ }
+ if channel.ChannelInfo.MultiKeyDisabledTime != nil {
+ if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
+ newDisabledTime[newIndex] = t
+ }
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason != nil {
+ if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
+ newDisabledReason[newIndex] = r
+ }
+ }
+ newIndex++
+ }
+
+ if len(remainingKeys) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "不能删除最后一个密钥",
+ })
+ return
+ }
+
+ // Update channel with remaining keys
+ channel.Key = strings.Join(remainingKeys, "\n")
+ channel.ChannelInfo.MultiKeySize = len(remainingKeys)
+ channel.ChannelInfo.MultiKeyStatusList = newStatusList
+ channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
+ channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "密钥已删除",
+ })
+ return
+
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string
diff --git a/controller/midjourney.go b/controller/midjourney.go
index a67d39c2..3a730441 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -13,6 +13,7 @@ import (
"one-api/model"
"one-api/service"
"one-api/setting"
+ "one-api/setting/system_setting"
"time"
"github.com/gin-gonic/gin"
@@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
if setting.MjForwardUrlEnabled {
for i, midjourney := range items {
- midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+ midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney
}
}
@@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
if setting.MjForwardUrlEnabled {
for i, midjourney := range items {
- midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+ midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney
}
}
diff --git a/controller/misc.go b/controller/misc.go
index 897dad25..07f7d3f0 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -42,6 +42,8 @@ func GetStatus(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
+ passkeySetting := system_setting.GetPasskeySettings()
+
data := gin.H{
"version": common.Version,
"start_time": common.StartTime,
@@ -58,11 +60,7 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "stripe_unit_price": setting.StripeUnitPrice,
- "min_topup": setting.MinTopUp,
- "stripe_min_topup": setting.StripeMinTopUp,
+ "server_address": system_setting.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
@@ -75,15 +73,15 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
- "pay_methods": setting.PayMethods,
- "usd_exchange_rate": setting.USDExchangeRate,
+
+ "usd_exchange_rate": operation_setting.USDExchangeRate,
+ "price": operation_setting.Price,
+ "stripe_unit_price": setting.StripeUnitPrice,
// 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled,
@@ -98,6 +96,13 @@ func GetStatus(c *gin.Context) {
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+ "passkey_login": passkeySetting.Enabled,
+ "passkey_display_name": passkeySetting.RPDisplayName,
+ "passkey_rp_id": passkeySetting.RPID,
+ "passkey_origins": passkeySetting.Origins,
+ "passkey_allow_insecure": passkeySetting.AllowInsecureOrigin,
+ "passkey_user_verification": passkeySetting.UserVerification,
+ "passkey_attachment": passkeySetting.AttachmentPreference,
"setup": constant.Setup,
}
@@ -253,7 +258,7 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("
您好,你正在进行%s密码重置。
"+
"
点击 此处 进行密码重置。
"+
diff --git a/controller/oidc.go b/controller/oidc.go
index f3def0e3..8e254d38 100644
--- a/controller/oidc.go
+++ b/controller/oidc.go
@@ -8,7 +8,6 @@ import (
"net/url"
"one-api/common"
"one-api/model"
- "one-api/setting"
"one-api/setting/system_setting"
"strconv"
"strings"
@@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
- values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
+ values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
diff --git a/controller/option.go b/controller/option.go
index e5f2b75b..7d1c676f 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -128,6 +128,33 @@ func UpdateOption(c *gin.Context) {
})
return
}
+ case "ImageRatio":
+ err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "图片倍率设置失败: " + err.Error(),
+ })
+ return
+ }
+ case "AudioRatio":
+ err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "音频倍率设置失败: " + err.Error(),
+ })
+ return
+ }
+ case "AudioCompletionRatio":
+ err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "音频补全倍率设置失败: " + err.Error(),
+ })
+ return
+ }
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
if err != nil {
diff --git a/controller/passkey.go b/controller/passkey.go
new file mode 100644
index 00000000..7ffacf5d
--- /dev/null
+++ b/controller/passkey.go
@@ -0,0 +1,497 @@
+package controller
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+ "time"
+
+ "one-api/common"
+ "one-api/model"
+ passkeysvc "one-api/service/passkey"
+ "one-api/setting/system_setting"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/go-webauthn/webauthn/protocol"
+ webauthnlib "github.com/go-webauthn/webauthn/webauthn"
+)
+
+func PasskeyRegisterBegin(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ credential, err := model.GetPasskeyByUserID(user.Id)
+ if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
+ common.ApiError(c, err)
+ return
+ }
+ if errors.Is(err, model.ErrPasskeyNotFound) {
+ credential = nil
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ waUser := passkeysvc.NewWebAuthnUser(user, credential)
+ var options []webauthnlib.RegistrationOption
+ if credential != nil {
+ descriptor := credential.ToWebAuthnCredential().Descriptor()
+ options = append(options, webauthnlib.WithExclusions([]protocol.CredentialDescriptor{descriptor}))
+ }
+
+ creation, sessionData, err := wa.BeginRegistration(waUser, options...)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ if err := passkeysvc.SaveSessionData(c, passkeysvc.RegistrationSessionKey, sessionData); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "options": creation,
+ },
+ })
+}
+
+func PasskeyRegisterFinish(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ credentialRecord, err := model.GetPasskeyByUserID(user.Id)
+ if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
+ common.ApiError(c, err)
+ return
+ }
+ if errors.Is(err, model.ErrPasskeyNotFound) {
+ credentialRecord = nil
+ }
+
+ sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.RegistrationSessionKey)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ waUser := passkeysvc.NewWebAuthnUser(user, credentialRecord)
+ credential, err := wa.FinishRegistration(waUser, *sessionData, c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ passkeyCredential := model.NewPasskeyCredentialFromWebAuthn(user.Id, credential)
+ if passkeyCredential == nil {
+ common.ApiErrorMsg(c, "无法创建 Passkey 凭证")
+ return
+ }
+
+ if err := model.UpsertPasskeyCredential(passkeyCredential); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Passkey 注册成功",
+ })
+}
+
+func PasskeyDelete(c *gin.Context) {
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ if err := model.DeletePasskeyByUserID(user.Id); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Passkey 已解绑",
+ })
+}
+
+func PasskeyStatus(c *gin.Context) {
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ credential, err := model.GetPasskeyByUserID(user.Id)
+ if errors.Is(err, model.ErrPasskeyNotFound) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "enabled": false,
+ },
+ })
+ return
+ }
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ data := gin.H{
+ "enabled": true,
+ "last_used_at": credential.LastUsedAt,
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": data,
+ })
+}
+
+func PasskeyLoginBegin(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ assertion, sessionData, err := wa.BeginDiscoverableLogin()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ if err := passkeysvc.SaveSessionData(c, passkeysvc.LoginSessionKey, sessionData); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "options": assertion,
+ },
+ })
+}
+
+func PasskeyLoginFinish(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.LoginSessionKey)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ handler := func(rawID, userHandle []byte) (webauthnlib.User, error) {
+ // 首先通过凭证ID查找用户
+ credential, err := model.GetPasskeyByCredentialID(rawID)
+ if err != nil {
+ return nil, fmt.Errorf("未找到 Passkey 凭证: %w", err)
+ }
+
+ // 通过凭证获取用户
+ user := &model.User{Id: credential.UserID}
+ if err := user.FillUserById(); err != nil {
+ return nil, fmt.Errorf("用户信息获取失败: %w", err)
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ return nil, errors.New("该用户已被禁用")
+ }
+
+ if len(userHandle) > 0 {
+ userID, parseErr := strconv.Atoi(string(userHandle))
+ if parseErr != nil {
+ // 记录异常但继续验证,因为某些客户端可能使用非数字格式
+ common.SysLog(fmt.Sprintf("PasskeyLogin: userHandle parse error for credential, length: %d", len(userHandle)))
+ } else if userID != user.Id {
+ return nil, errors.New("用户句柄与凭证不匹配")
+ }
+ }
+
+ return passkeysvc.NewWebAuthnUser(user, credential), nil
+ }
+
+ waUser, credential, err := wa.FinishPasskeyLogin(handler, *sessionData, c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ userWrapper, ok := waUser.(*passkeysvc.WebAuthnUser)
+ if !ok {
+ common.ApiErrorMsg(c, "Passkey 登录状态异常")
+ return
+ }
+
+ modelUser := userWrapper.ModelUser()
+ if modelUser == nil {
+ common.ApiErrorMsg(c, "Passkey 登录状态异常")
+ return
+ }
+
+ if modelUser.Status != common.UserStatusEnabled {
+ common.ApiErrorMsg(c, "该用户已被禁用")
+ return
+ }
+
+ // 更新凭证信息
+ updatedCredential := model.NewPasskeyCredentialFromWebAuthn(modelUser.Id, credential)
+ if updatedCredential == nil {
+ common.ApiErrorMsg(c, "Passkey 凭证更新失败")
+ return
+ }
+ now := time.Now()
+ updatedCredential.LastUsedAt = &now
+ if err := model.UpsertPasskeyCredential(updatedCredential); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ setupLogin(modelUser, c)
+ return
+}
+
+func AdminResetPasskey(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiErrorMsg(c, "无效的用户 ID")
+ return
+ }
+
+ user := &model.User{Id: id}
+ if err := user.FillUserById(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ if _, err := model.GetPasskeyByUserID(user.Id); err != nil {
+ if errors.Is(err, model.ErrPasskeyNotFound) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户尚未绑定 Passkey",
+ })
+ return
+ }
+ common.ApiError(c, err)
+ return
+ }
+
+ if err := model.DeletePasskeyByUserID(user.Id); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Passkey 已重置",
+ })
+}
+
+func PasskeyVerifyBegin(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ credential, err := model.GetPasskeyByUserID(user.Id)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户尚未绑定 Passkey",
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ waUser := passkeysvc.NewWebAuthnUser(user, credential)
+ assertion, sessionData, err := wa.BeginLogin(waUser)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ if err := passkeysvc.SaveSessionData(c, passkeysvc.VerifySessionKey, sessionData); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "options": assertion,
+ },
+ })
+}
+
+func PasskeyVerifyFinish(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ user, err := getSessionUser(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ credential, err := model.GetPasskeyByUserID(user.Id)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户尚未绑定 Passkey",
+ })
+ return
+ }
+
+ sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ waUser := passkeysvc.NewWebAuthnUser(user, credential)
+ _, err = wa.FinishLogin(waUser, *sessionData, c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 更新凭证的最后使用时间
+ now := time.Now()
+ credential.LastUsedAt = &now
+ if err := model.UpsertPasskeyCredential(credential); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Passkey 验证成功",
+ })
+}
+
+func getSessionUser(c *gin.Context) (*model.User, error) {
+ session := sessions.Default(c)
+ idRaw := session.Get("id")
+ if idRaw == nil {
+ return nil, errors.New("未登录")
+ }
+ id, ok := idRaw.(int)
+ if !ok {
+ return nil, errors.New("无效的会话信息")
+ }
+ user := &model.User{Id: id}
+ if err := user.FillUserById(); err != nil {
+ return nil, err
+ }
+ if user.Status != common.UserStatusEnabled {
+ return nil, errors.New("该用户已被禁用")
+ }
+ return user, nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index d3d93192..23d72515 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
- preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return
}
defer func() {
// Only return quota if downstream failed and quota was actually pre-consumed
- if newAPIError != nil && preConsumedQuota != 0 {
- service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+ if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
+ service.ReturnPreConsumedQuota(c, relayInfo)
}
}()
@@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
-
- gopool.Go(func() {
- // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
- // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
- if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+ // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+ // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+ if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+ gopool.Go(func() {
service.DisableChannel(channelError, err.Error())
- }
- })
+ })
+ }
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中
diff --git a/controller/secure_verification.go b/controller/secure_verification.go
new file mode 100644
index 00000000..1c5f0981
--- /dev/null
+++ b/controller/secure_verification.go
@@ -0,0 +1,313 @@
+package controller
+
+import (
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ passkeysvc "one-api/service/passkey"
+ "one-api/setting/system_setting"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // SecureVerificationSessionKey 安全验证的 session key
+ SecureVerificationSessionKey = "secure_verified_at"
+ // SecureVerificationTimeout 验证有效期(秒)
+ SecureVerificationTimeout = 300 // 5分钟
+)
+
+type UniversalVerifyRequest struct {
+ Method string `json:"method"` // "2fa" 或 "passkey"
+ Code string `json:"code,omitempty"`
+}
+
+type VerificationStatusResponse struct {
+ Verified bool `json:"verified"`
+ ExpiresAt int64 `json:"expires_at,omitempty"`
+}
+
+// UniversalVerify 通用验证接口
+// 支持 2FA 和 Passkey 验证,验证成功后在 session 中记录时间戳
+func UniversalVerify(c *gin.Context) {
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "未登录",
+ })
+ return
+ }
+
+ var req UniversalVerifyRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ common.ApiError(c, fmt.Errorf("参数错误: %v", err))
+ return
+ }
+
+ // 获取用户信息
+ user := &model.User{Id: userId}
+ if err := user.FillUserById(); err != nil {
+ common.ApiError(c, fmt.Errorf("获取用户信息失败: %v", err))
+ return
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ common.ApiError(c, fmt.Errorf("该用户已被禁用"))
+ return
+ }
+
+ // 检查用户的验证方式
+ twoFA, _ := model.GetTwoFAByUserId(userId)
+ has2FA := twoFA != nil && twoFA.IsEnabled
+
+ passkey, passkeyErr := model.GetPasskeyByUserID(userId)
+ hasPasskey := passkeyErr == nil && passkey != nil
+
+ if !has2FA && !hasPasskey {
+ common.ApiError(c, fmt.Errorf("用户未启用2FA或Passkey"))
+ return
+ }
+
+ // 根据验证方式进行验证
+ var verified bool
+ var verifyMethod string
+
+ switch req.Method {
+ case "2fa":
+ if !has2FA {
+ common.ApiError(c, fmt.Errorf("用户未启用2FA"))
+ return
+ }
+ if req.Code == "" {
+ common.ApiError(c, fmt.Errorf("验证码不能为空"))
+ return
+ }
+ verified = validateTwoFactorAuth(twoFA, req.Code)
+ verifyMethod = "2FA"
+
+ case "passkey":
+ if !hasPasskey {
+ common.ApiError(c, fmt.Errorf("用户未启用Passkey"))
+ return
+ }
+ // Passkey 验证需要先调用 PasskeyVerifyBegin 和 PasskeyVerifyFinish
+ // 这里只是验证 Passkey 验证流程是否已经完成
+ // 实际上,前端应该先调用这两个接口,然后再调用本接口
+ verified = true // Passkey 验证逻辑已在 PasskeyVerifyFinish 中完成
+ verifyMethod = "Passkey"
+
+ default:
+ common.ApiError(c, fmt.Errorf("不支持的验证方式: %s", req.Method))
+ return
+ }
+
+ if !verified {
+ common.ApiError(c, fmt.Errorf("验证失败,请检查验证码"))
+ return
+ }
+
+ // 验证成功,在 session 中记录时间戳
+ session := sessions.Default(c)
+ now := time.Now().Unix()
+ session.Set(SecureVerificationSessionKey, now)
+ if err := session.Save(); err != nil {
+ common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
+ return
+ }
+
+ // 记录日志
+ model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("通用安全验证成功 (验证方式: %s)", verifyMethod))
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "验证成功",
+ "data": gin.H{
+ "verified": true,
+ "expires_at": now + SecureVerificationTimeout,
+ },
+ })
+}
+
+// GetVerificationStatus 获取验证状态
+func GetVerificationStatus(c *gin.Context) {
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "未登录",
+ })
+ return
+ }
+
+ session := sessions.Default(c)
+ verifiedAtRaw := session.Get(SecureVerificationSessionKey)
+
+ if verifiedAtRaw == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": VerificationStatusResponse{
+ Verified: false,
+ },
+ })
+ return
+ }
+
+ verifiedAt, ok := verifiedAtRaw.(int64)
+ if !ok {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": VerificationStatusResponse{
+ Verified: false,
+ },
+ })
+ return
+ }
+
+ elapsed := time.Now().Unix() - verifiedAt
+ if elapsed >= SecureVerificationTimeout {
+ // 验证已过期
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": VerificationStatusResponse{
+ Verified: false,
+ },
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": VerificationStatusResponse{
+ Verified: true,
+ ExpiresAt: verifiedAt + SecureVerificationTimeout,
+ },
+ })
+}
+
+// CheckSecureVerification 检查是否已通过安全验证
+// 返回 true 表示验证有效,false 表示需要重新验证
+func CheckSecureVerification(c *gin.Context) bool {
+ session := sessions.Default(c)
+ verifiedAtRaw := session.Get(SecureVerificationSessionKey)
+
+ if verifiedAtRaw == nil {
+ return false
+ }
+
+ verifiedAt, ok := verifiedAtRaw.(int64)
+ if !ok {
+ return false
+ }
+
+ elapsed := time.Now().Unix() - verifiedAt
+ if elapsed >= SecureVerificationTimeout {
+ // 验证已过期,清除 session
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+ return false
+ }
+
+ return true
+}
+
+// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
+// 这是一个辅助函数,供 PasskeyVerifyFinish 调用
+func PasskeyVerifyAndSetSession(c *gin.Context) {
+ session := sessions.Default(c)
+ now := time.Now().Unix()
+ session.Set(SecureVerificationSessionKey, now)
+ _ = session.Save()
+}
+
+// PasskeyVerifyForSecure 用于安全验证的 Passkey 验证流程
+// 整合了 begin 和 finish 流程
+func PasskeyVerifyForSecure(c *gin.Context) {
+ if !system_setting.GetPasskeySettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未启用 Passkey 登录",
+ })
+ return
+ }
+
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "未登录",
+ })
+ return
+ }
+
+ user := &model.User{Id: userId}
+ if err := user.FillUserById(); err != nil {
+ common.ApiError(c, fmt.Errorf("获取用户信息失败: %v", err))
+ return
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ common.ApiError(c, fmt.Errorf("该用户已被禁用"))
+ return
+ }
+
+ credential, err := model.GetPasskeyByUserID(userId)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户尚未绑定 Passkey",
+ })
+ return
+ }
+
+ wa, err := passkeysvc.BuildWebAuthn(c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ waUser := passkeysvc.NewWebAuthnUser(user, credential)
+ sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ _, err = wa.FinishLogin(waUser, *sessionData, c.Request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 更新凭证的最后使用时间
+ now := time.Now()
+ credential.LastUsedAt = &now
+ if err := model.UpsertPasskeyCredential(credential); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 验证成功,设置 session
+ PasskeyVerifyAndSetSession(c)
+
+ // 记录日志
+ model.RecordLog(userId, model.LogTypeSystem, "Passkey 安全验证成功")
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Passkey 验证成功",
+ "data": gin.H{
+ "verified": true,
+ "expires_at": time.Now().Unix() + SecureVerificationTimeout,
+ },
+ })
+}
diff --git a/controller/setup.go b/controller/setup.go
index 8943a1a0..3ae255e9 100644
--- a/controller/setup.go
+++ b/controller/setup.go
@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "请求参数有误",
})
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
if !rootExists {
// Validate username length: max 12 characters to align with model.User validation
if len(req.Username) > 12 {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "用户名长度不能超过12个字符",
})
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
}
// Validate password
if req.Password != req.ConfirmPassword {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
}
if len(req.Password) < 8 {
- c.JSON(400, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&rootUser).Error
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&setup).Error
if err != nil {
- c.JSON(500, gin.H{
+ c.JSON(200, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})
@@ -178,4 +178,4 @@ func boolToString(b bool) string {
return "true"
}
return "false"
-}
+}
\ No newline at end of file
diff --git a/controller/task_video.go b/controller/task_video.go
index 84b78f90..73d5c39b 100644
--- a/controller/task_video.go
+++ b/controller/task_video.go
@@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
} else {
- task.Data = responseBody
+ task.Data = redactVideoResponseBody(responseBody)
}
now := time.Now().Unix()
@@ -117,7 +117,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if task.FinishTime == 0 {
task.FinishTime = now
}
- task.FailReason = taskResult.Url
+ if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
+ task.FailReason = taskResult.Url
+ }
case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure
task.Progress = "100%"
@@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return nil
}
+
+func redactVideoResponseBody(body []byte) []byte {
+ var m map[string]any
+ if err := json.Unmarshal(body, &m); err != nil {
+ return body
+ }
+ resp, _ := m["response"].(map[string]any)
+ if resp != nil {
+ delete(resp, "bytesBase64Encoded")
+ if v, ok := resp["video"].(string); ok {
+ resp["video"] = truncateBase64(v)
+ }
+ if vs, ok := resp["videos"].([]any); ok {
+ for i := range vs {
+ if vm, ok := vs[i].(map[string]any); ok {
+ delete(vm, "bytesBase64Encoded")
+ }
+ }
+ }
+ }
+ b, err := json.Marshal(m)
+ if err != nil {
+ return body
+ }
+ return b
+}
+
+func truncateBase64(s string) string {
+ const maxKeep = 256
+ if len(s) <= maxKeep {
+ return s
+ }
+ return s[:maxKeep] + "..."
+}
diff --git a/controller/telegram.go b/controller/telegram.go
index 8d07fc94..2b1ec4fc 100644
--- a/controller/telegram.go
+++ b/controller/telegram.go
@@ -65,7 +65,7 @@ func TelegramBind(c *gin.Context) {
return
}
- c.Redirect(302, "/setting")
+ c.Redirect(302, "/console/personal")
}
func TelegramLogin(c *gin.Context) {
diff --git a/controller/topup.go b/controller/topup.go
index 3f3c8623..243e6794 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -9,6 +9,8 @@ import (
"one-api/model"
"one-api/service"
"one-api/setting"
+ "one-api/setting/operation_setting"
+ "one-api/setting/system_setting"
"strconv"
"sync"
"time"
@@ -19,6 +21,44 @@ import (
"github.com/shopspring/decimal"
)
+func GetTopUpInfo(c *gin.Context) {
+ // 获取支付方式
+ payMethods := operation_setting.PayMethods
+
+ // 如果启用了 Stripe 支付,添加到支付方法列表
+ if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
+ // 检查是否已经包含 Stripe
+ hasStripe := false
+ for _, method := range payMethods {
+ if method["type"] == "stripe" {
+ hasStripe = true
+ break
+ }
+ }
+
+ if !hasStripe {
+ stripeMethod := map[string]string{
+ "name": "Stripe",
+ "type": "stripe",
+ "color": "rgba(var(--semi-purple-5), 1)",
+ "min_topup": strconv.Itoa(setting.StripeMinTopUp),
+ }
+ payMethods = append(payMethods, stripeMethod)
+ }
+ }
+
+ data := gin.H{
+ "enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
+ "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
+ "pay_methods": payMethods,
+ "min_topup": operation_setting.MinTopUp,
+ "stripe_min_topup": setting.StripeMinTopUp,
+ "amount_options": operation_setting.GetPaymentSetting().AmountOptions,
+ "discount": operation_setting.GetPaymentSetting().AmountDiscount,
+ }
+ common.ApiSuccess(c, data)
+}
+
type EpayRequest struct {
Amount int64 `json:"amount"`
PaymentMethod string `json:"payment_method"`
@@ -31,13 +71,13 @@ type AmountRequest struct {
}
func GetEpayClient() *epay.Client {
- if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
+ if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
return nil
}
withUrl, err := epay.NewClient(&epay.Config{
- PartnerID: setting.EpayId,
- Key: setting.EpayKey,
- }, setting.PayAddress)
+ PartnerID: operation_setting.EpayId,
+ Key: operation_setting.EpayKey,
+ }, operation_setting.PayAddress)
if err != nil {
return nil
}
@@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
}
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
- dPrice := decimal.NewFromFloat(setting.Price)
+ dPrice := decimal.NewFromFloat(operation_setting.Price)
+ // apply optional preset discount by the original request amount (if configured), default 1.0
+ discount := 1.0
+ if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
+ if ds > 0 {
+ discount = ds
+ }
+ }
+ dDiscount := decimal.NewFromFloat(discount)
- payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
+ payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
return payMoney.InexactFloat64()
}
func getMinTopup() int64 {
- minTopup := setting.MinTopUp
+ minTopup := operation_setting.MinTopUp
if !common.DisplayInCurrencyEnabled {
dMinTopup := decimal.NewFromInt(int64(minTopup))
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
@@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
return
}
- if !setting.ContainsPayMethod(req.PaymentMethod) {
+ if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
return
}
callBackAddress := service.GetCallbackAddress()
- returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
+ returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go
index eb320809..9a568d85 100644
--- a/controller/topup_stripe.go
+++ b/controller/topup_stripe.go
@@ -8,6 +8,8 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
+ "one-api/setting/operation_setting"
+ "one-api/setting/system_setting"
"strconv"
"strings"
"time"
@@ -215,15 +217,16 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
- SuccessURL: stripe.String(setting.ServerAddress + "/log"),
- CancelURL: stripe.String(setting.ServerAddress + "/topup"),
+ SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
+ CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(setting.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
- Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
+ Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
+ AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
}
if "" == customerId {
@@ -254,6 +257,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
}
func getStripePayMoney(amount float64, group string) float64 {
+ originalAmount := amount
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
}
@@ -262,7 +266,14 @@ func getStripePayMoney(amount float64, group string) float64 {
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
- payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
+ // apply optional preset discount by the original request amount (if configured), default 1.0
+ discount := 1.0
+ if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
+ if ds > 0 {
+ discount = ds
+ }
+ }
+ payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
return payMoney
}
diff --git a/controller/user.go b/controller/user.go
index 982329ce..33d4636b 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -450,6 +450,10 @@ func GetSelf(c *gin.Context) {
"role": user.Role,
"status": user.Status,
"email": user.Email,
+ "github_id": user.GitHubId,
+ "oidc_id": user.OidcId,
+ "wechat_id": user.WeChatId,
+ "telegram_id": user.TelegramId,
"group": user.Group,
"quota": user.Quota,
"used_quota": user.UsedQuota,
@@ -1098,6 +1102,9 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
BarkUrl string `json:"bark_url,omitempty"`
+ GotifyUrl string `json:"gotify_url,omitempty"`
+ GotifyToken string `json:"gotify_token,omitempty"`
+ GotifyPriority int `json:"gotify_priority,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
RecordIpLog bool `json:"record_ip_log"`
}
@@ -1113,7 +1120,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 验证预警类型
- if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
+ if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark && req.QuotaWarningType != dto.NotifyTypeGotify {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
@@ -1188,6 +1195,40 @@ func UpdateUserSetting(c *gin.Context) {
}
}
+ // 如果是Gotify类型,验证Gotify URL和Token
+ if req.QuotaWarningType == dto.NotifyTypeGotify {
+ if req.GotifyUrl == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Gotify服务器地址不能为空",
+ })
+ return
+ }
+ if req.GotifyToken == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Gotify令牌不能为空",
+ })
+ return
+ }
+ // 验证URL格式
+ if _, err := url.ParseRequestURI(req.GotifyUrl); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的Gotify服务器地址",
+ })
+ return
+ }
+ // 检查是否是HTTP或HTTPS
+ if !strings.HasPrefix(req.GotifyUrl, "https://") && !strings.HasPrefix(req.GotifyUrl, "http://") {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Gotify服务器地址必须以http://或https://开头",
+ })
+ return
+ }
+ }
+
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
@@ -1221,6 +1262,18 @@ func UpdateUserSetting(c *gin.Context) {
settings.BarkUrl = req.BarkUrl
}
+ // 如果是Gotify类型,添加Gotify配置到设置中
+ if req.QuotaWarningType == dto.NotifyTypeGotify {
+ settings.GotifyUrl = req.GotifyUrl
+ settings.GotifyToken = req.GotifyToken
+ // Gotify优先级范围0-10,超出范围则使用默认值5
+ if req.GotifyPriority < 0 || req.GotifyPriority > 10 {
+ settings.GotifyPriority = 5
+ } else {
+ settings.GotifyPriority = req.GotifyPriority
+ }
+ }
+
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
diff --git a/dto/channel_settings.go b/dto/channel_settings.go
index 2c58795c..d57184b3 100644
--- a/dto/channel_settings.go
+++ b/dto/channel_settings.go
@@ -9,6 +9,25 @@ type ChannelSettings struct {
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
}
+type VertexKeyType string
+
+const (
+ VertexKeyTypeJSON VertexKeyType = "json"
+ VertexKeyTypeAPIKey VertexKeyType = "api_key"
+)
+
type ChannelOtherSettings struct {
- AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
+ AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
+ VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
+ OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
+ AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
+ DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
+ AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
+}
+
+func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
+ if s == nil || s.OpenRouterEnterprise == nil {
+ return false
+ }
+ return *s.OpenRouterEnterprise
}
diff --git a/dto/claude.go b/dto/claude.go
index 963e588b..dfc5cfd4 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -195,11 +195,15 @@ type ClaudeRequest struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
- //ClaudeMetadata `json:"metadata,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- Thinking *Thinking `json:"thinking,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ContextManagement json.RawMessage `json:"context_management,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ Thinking *Thinking `json:"thinking,omitempty"`
+ McpServers json.RawMessage `json:"mcp_servers,omitempty"`
+ Metadata json.RawMessage `json:"metadata,omitempty"`
+ // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
+ ServiceTier string `json:"service_tier,omitempty"`
}
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
diff --git a/dto/gemini.go b/dto/gemini.go
index cd5d74cd..80552aad 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -2,12 +2,11 @@ package dto
import (
"encoding/json"
+ "github.com/gin-gonic/gin"
"one-api/common"
"one-api/logger"
"one-api/types"
"strings"
-
- "github.com/gin-gonic/gin"
)
type GeminiChatRequest struct {
@@ -15,7 +14,30 @@ type GeminiChatRequest struct {
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
+ ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
+ CachedContent string `json:"cachedContent,omitempty"`
+}
+
+type ToolConfig struct {
+ FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
+ RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
+}
+
+type FunctionCallingConfig struct {
+ Mode FunctionCallingConfigMode `json:"mode,omitempty"`
+ AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
+}
+type FunctionCallingConfigMode string
+
+type RetrievalConfig struct {
+ LatLng *LatLng `json:"latLng,omitempty"`
+ LanguageCode string `json:"languageCode,omitempty"`
+}
+
+type LatLng struct {
+ Latitude *float64 `json:"latitude,omitempty"`
+ Longitude *float64 `json:"longitude,omitempty"`
}
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
@@ -229,6 +251,7 @@ type GeminiChatTool struct {
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+ URLContext any `json:"urlContext,omitempty"`
}
type GeminiChatGenerationConfig struct {
@@ -240,12 +263,20 @@ type GeminiChatGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
+ ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
+ PresencePenalty *float32 `json:"presencePenalty,omitempty"`
+ FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
+ ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
+ Logprobs *int32 `json:"logprobs,omitempty"`
+ MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
+type MediaResolution string
+
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason *string `json:"finishReason"`
@@ -269,15 +300,14 @@ type GeminiChatResponse struct {
}
type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount"`
- PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"`
- CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"`
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
}
-type GeminiModalityTokenCount struct {
+type GeminiPromptTokensDetails struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
diff --git a/dto/openai_image.go b/dto/openai_image.go
index 9e838688..5aece25f 100644
--- a/dto/openai_image.go
+++ b/dto/openai_image.go
@@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
return nil
}
+// 序列化时需要重新把字段平铺
+func (r ImageRequest) MarshalJSON() ([]byte, error) {
+ // 将已定义字段转为 map
+ type Alias ImageRequest
+ alias := Alias(r)
+ base, err := common.Marshal(alias)
+ if err != nil {
+ return nil, err
+ }
+
+ var baseMap map[string]json.RawMessage
+ if err := common.Unmarshal(base, &baseMap); err != nil {
+ return nil, err
+ }
+
+ // 合并 ExtraFields
+ for k, v := range r.Extra {
+ if _, exists := baseMap[k]; !exists {
+ baseMap[k] = v
+ }
+ }
+
+ return json.Marshal(baseMap)
+}
+
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
fields := make(map[string]struct{})
for i := 0; i < t.NumField(); i++ {
diff --git a/dto/openai_request.go b/dto/openai_request.go
index cd05a63c..dbdfad44 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -57,6 +57,18 @@ type GeneralOpenAIRequest struct {
Dimensions int `json:"dimensions,omitempty"`
Modalities json.RawMessage `json:"modalities,omitempty"`
Audio json.RawMessage `json:"audio,omitempty"`
+ // 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
+ // 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私
+ SafetyIdentifier string `json:"safety_identifier,omitempty"`
+ // Whether or not to store the output of this chat completion request for use in our model distillation or evals products.
+ // 是否存储此次请求数据供 OpenAI 用于评估和优化产品
+ // 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用
+ Store json.RawMessage `json:"store,omitempty"`
+ // Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field
+ PromptCacheKey string `json:"prompt_cache_key,omitempty"`
+ LogitBias json.RawMessage `json:"logit_bias,omitempty"`
+ Metadata json.RawMessage `json:"metadata,omitempty"`
+ Prediction json.RawMessage `json:"prediction,omitempty"`
// gemini
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
//xai
@@ -772,21 +784,23 @@ type OpenAIResponsesRequest struct {
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
- ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
+ ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
- ServiceTier string `json:"service_tier,omitempty"`
- Store bool `json:"store,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- Text json.RawMessage `json:"text,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
- Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
- TopP float64 `json:"top_p,omitempty"`
- Truncation string `json:"truncation,omitempty"`
- User string `json:"user,omitempty"`
- MaxToolCalls uint `json:"max_tool_calls,omitempty"`
- Prompt json.RawMessage `json:"prompt,omitempty"`
+ // 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
+ ServiceTier string `json:"service_tier,omitempty"`
+ Store json.RawMessage `json:"store,omitempty"`
+ PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Text json.RawMessage `json:"text,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
+ TopP float64 `json:"top_p,omitempty"`
+ Truncation string `json:"truncation,omitempty"`
+ User string `json:"user,omitempty"`
+ MaxToolCalls uint `json:"max_tool_calls,omitempty"`
+ Prompt json.RawMessage `json:"prompt,omitempty"`
}
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
diff --git a/dto/openai_response.go b/dto/openai_response.go
index 966748cb..6353c15f 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -6,6 +6,10 @@ import (
"one-api/types"
)
+const (
+ ResponsesOutputTypeImageGenerationCall = "image_generation_call"
+)
+
type SimpleResponse struct {
Usage `json:"usage"`
Error any `json:"error"`
@@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
+func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
+ if len(o.Output) == 0 {
+ return false
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return true
+ }
+ }
+ return false
+}
+
+func (o *OpenAIResponsesResponse) GetQuality() string {
+ if len(o.Output) == 0 {
+ return ""
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return output.Quality
+ }
+ }
+ return ""
+}
+
+func (o *OpenAIResponsesResponse) GetSize() string {
+ if len(o.Output) == 0 {
+ return ""
+ }
+ for _, output := range o.Output {
+ if output.Type == ResponsesOutputTypeImageGenerationCall {
+ return output.Size
+ }
+ }
+ return ""
+}
+
type IncompleteDetails struct {
Reasoning string `json:"reasoning"`
}
@@ -283,6 +323,8 @@ type ResponsesOutput struct {
Status string `json:"status"`
Role string `json:"role"`
Content []ResponsesOutputContent `json:"content"`
+ Quality string `json:"quality"`
+ Size string `json:"size"`
}
type ResponsesOutputContent struct {
diff --git a/dto/user_settings.go b/dto/user_settings.go
index 89dd926e..16ce7b98 100644
--- a/dto/user_settings.go
+++ b/dto/user_settings.go
@@ -7,6 +7,9 @@ type UserSetting struct {
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
+ GotifyUrl string `json:"gotify_url,omitempty"` // GotifyUrl Gotify服务器地址
+ GotifyToken string `json:"gotify_token,omitempty"` // GotifyToken Gotify应用令牌
+ GotifyPriority int `json:"gotify_priority"` // GotifyPriority Gotify消息优先级
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
@@ -16,4 +19,5 @@ var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
NotifyTypeBark = "bark" // Bark 推送
+ NotifyTypeGotify = "gotify" // Gotify 推送
)
diff --git a/go.mod b/go.mod
index 501d966d..66a452ce 100644
--- a/go.mod
+++ b/go.mod
@@ -1,7 +1,9 @@
module one-api
// +heroku goVersion go1.18
-go 1.23.4
+go 1.24.0
+
+toolchain go1.24.6
require (
github.com/Calcium-Ion/go-epay v0.0.4
@@ -20,6 +22,7 @@ require (
github.com/glebarez/sqlite v1.9.0
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
+ github.com/go-webauthn/webauthn v0.14.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
@@ -35,10 +38,10 @@ require (
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.6.2
- golang.org/x/crypto v0.35.0
+ golang.org/x/crypto v0.42.0
golang.org/x/image v0.23.0
- golang.org/x/net v0.35.0
- golang.org/x/sync v0.11.0
+ golang.org/x/net v0.43.0
+ golang.org/x/sync v0.17.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -58,6 +61,7 @@ require (
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
@@ -65,8 +69,11 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
+ github.com/go-webauthn/x v0.1.25 // indirect
github.com/goccy/go-json v0.10.2 // indirect
+ github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
+ github.com/google/go-tpm v0.9.5 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -91,11 +98,12 @@ require (
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
+ github.com/x448/float16 v0.8.4 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
- golang.org/x/sys v0.30.0 // indirect
- golang.org/x/text v0.22.0 // indirect
+ golang.org/x/sys v0.36.0 // indirect
+ golang.org/x/text v0.29.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect
diff --git a/go.sum b/go.sum
index 189d09de..a62b8321 100644
--- a/go.sum
+++ b/go.sum
@@ -47,6 +47,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
+github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
@@ -89,16 +91,24 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
+github.com/go-webauthn/webauthn v0.14.0 h1:ZLNPUgPcDlAeoxe+5umWG/tEeCoQIDr7gE2Zx2QnhL0=
+github.com/go-webauthn/webauthn v0.14.0/go.mod h1:QZzPFH3LJ48u5uEPAu+8/nWJImoLBWM7iAH/kSVSo6k=
+github.com/go-webauthn/x v0.1.25 h1:g/0noooIGcz/yCVqebcFgNnGIgBlJIccS+LYAa+0Z88=
+github.com/go-webauthn/x v0.1.25/go.mod h1:ieblaPY1/BVCV0oQTsA/VAo08/TWayQuJuo5Q+XxmTY=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
+github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
+github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
+github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
@@ -200,8 +210,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
@@ -229,27 +240,31 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
+github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
+go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
+go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
-golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
+golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
+golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
-golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
+golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
-golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
+golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -261,14 +276,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
-golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
+golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
-golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
+golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
+golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
diff --git a/main.go b/main.go
index cc2288a6..ba96d209 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "bytes"
"embed"
"fmt"
"log"
@@ -16,6 +17,8 @@ import (
"one-api/setting/ratio_setting"
"os"
"strconv"
+ "strings"
+ "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
@@ -33,6 +36,7 @@ var buildFS embed.FS
var indexPage []byte
func main() {
+ startTime := time.Now()
err := InitResources()
if err != nil {
@@ -145,11 +149,31 @@ func main() {
})
server.Use(sessions.Sessions("session", store))
+ analyticsInjectBuilder := &strings.Builder{}
+ if os.Getenv("UMAMI_WEBSITE_ID") != "" {
+ umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
+ umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL")
+ if umamiScriptURL == "" {
+ umamiScriptURL = "https://analytics.umami.is/script.js"
+ }
+ analyticsInjectBuilder.WriteString("")
+ }
+ analyticsInject := analyticsInjectBuilder.String()
+ indexPage = bytes.ReplaceAll(indexPage, []byte("
\n"), []byte(analyticsInject))
+
router.SetRouter(server, buildFS, indexPage)
var port = os.Getenv("PORT")
if port == "" {
port = strconv.Itoa(*common.Port)
}
+
+ // Log startup success message
+ common.LogStartupSuccess(startTime, port)
+
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
@@ -161,8 +185,9 @@ func InitResources() error {
// This is a placeholder function for future resource initialization
err := godotenv.Load(".env")
if err != nil {
- common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
- common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+ if common.DebugEnabled {
+ common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+ }
}
// 加载环境变量
diff --git a/middleware/secure_verification.go b/middleware/secure_verification.go
new file mode 100644
index 00000000..19fae9a5
--- /dev/null
+++ b/middleware/secure_verification.go
@@ -0,0 +1,131 @@
+package middleware
+
+import (
+ "net/http"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // SecureVerificationSessionKey 安全验证的 session key(与 controller 保持一致)
+ SecureVerificationSessionKey = "secure_verified_at"
+ // SecureVerificationTimeout 验证有效期(秒)
+ SecureVerificationTimeout = 300 // 5分钟
+)
+
+// SecureVerificationRequired 安全验证中间件
+// 检查用户是否在有效时间内通过了安全验证
+// 如果未验证或验证已过期,返回 401 错误
+func SecureVerificationRequired() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // 检查用户是否已登录
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "未登录",
+ })
+ c.Abort()
+ return
+ }
+
+ // 检查 session 中的验证时间戳
+ session := sessions.Default(c)
+ verifiedAtRaw := session.Get(SecureVerificationSessionKey)
+
+ if verifiedAtRaw == nil {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "需要安全验证",
+ "code": "VERIFICATION_REQUIRED",
+ })
+ c.Abort()
+ return
+ }
+
+ verifiedAt, ok := verifiedAtRaw.(int64)
+ if !ok {
+ // session 数据格式错误
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "验证状态异常,请重新验证",
+ "code": "VERIFICATION_INVALID",
+ })
+ c.Abort()
+ return
+ }
+
+ // 检查验证是否过期
+ elapsed := time.Now().Unix() - verifiedAt
+ if elapsed >= SecureVerificationTimeout {
+ // 验证已过期,清除 session
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "验证已过期,请重新验证",
+ "code": "VERIFICATION_EXPIRED",
+ })
+ c.Abort()
+ return
+ }
+
+ // 验证有效,继续处理请求
+ c.Next()
+ }
+}
+
+// OptionalSecureVerification 可选的安全验证中间件
+// 如果用户已验证,则在 context 中设置标记,但不阻止请求继续
+// 用于某些需要区分是否已验证的场景
+func OptionalSecureVerification() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.Set("secure_verified", false)
+ c.Next()
+ return
+ }
+
+ session := sessions.Default(c)
+ verifiedAtRaw := session.Get(SecureVerificationSessionKey)
+
+ if verifiedAtRaw == nil {
+ c.Set("secure_verified", false)
+ c.Next()
+ return
+ }
+
+ verifiedAt, ok := verifiedAtRaw.(int64)
+ if !ok {
+ c.Set("secure_verified", false)
+ c.Next()
+ return
+ }
+
+ elapsed := time.Now().Unix() - verifiedAt
+ if elapsed >= SecureVerificationTimeout {
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+ c.Set("secure_verified", false)
+ c.Next()
+ return
+ }
+
+ c.Set("secure_verified", true)
+ c.Set("secure_verified_at", verifiedAt)
+ c.Next()
+ }
+}
+
+// ClearSecureVerification 清除安全验证状态
+// 用于用户登出或需要强制重新验证的场景
+func ClearSecureVerification(c *gin.Context) {
+ session := sessions.Default(c)
+ session.Delete(SecureVerificationSessionKey)
+ _ = session.Save()
+}
diff --git a/model/channel.go b/model/channel.go
index a61b3ecc..534e2f3f 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -42,7 +42,6 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
- OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
@@ -51,6 +50,8 @@ type Channel struct {
// add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
+ OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
+
// cache info
Keys []string `json:"-" gorm:"-"`
}
diff --git a/model/main.go b/model/main.go
index 1a38d371..14384caf 100644
--- a/model/main.go
+++ b/model/main.go
@@ -251,6 +251,7 @@ func migrateDB() error {
&Channel{},
&Token{},
&User{},
+ &PasskeyCredential{},
&Option{},
&Redemption{},
&Ability{},
@@ -283,6 +284,7 @@ func migrateDBFast() error {
{&Channel{}, "Channel"},
{&Token{}, "Token"},
{&User{}, "User"},
+ {&PasskeyCredential{}, "PasskeyCredential"},
{&Option{}, "Option"},
{&Redemption{}, "Redemption"},
{&Ability{}, "Ability"},
diff --git a/model/option.go b/model/option.go
index 2121710c..9ace8fec 100644
--- a/model/option.go
+++ b/model/option.go
@@ -6,6 +6,7 @@ import (
"one-api/setting/config"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
+ "one-api/setting/system_setting"
"strconv"
"strings"
"time"
@@ -66,26 +67,27 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
- common.OptionMap["WorkerUrl"] = setting.WorkerUrl
- common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
- common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
+ common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
+ common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
+ common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
- common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
- common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
- common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
+ common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
+ common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
+ common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
common.OptionMap["StripePriceId"] = setting.StripePriceId
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
+ common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
- common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
+ common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -111,6 +113,9 @@ func InitOptionMap() {
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
+ common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
+ common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
+ common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -271,7 +276,7 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
- setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
@@ -293,29 +298,29 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken":
common.SMTPToken = value
case "ServerAddress":
- setting.ServerAddress = value
+ system_setting.ServerAddress = value
case "WorkerUrl":
- setting.WorkerUrl = value
+ system_setting.WorkerUrl = value
case "WorkerValidKey":
- setting.WorkerValidKey = value
+ system_setting.WorkerValidKey = value
case "PayAddress":
- setting.PayAddress = value
+ operation_setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
- setting.CustomCallbackAddress = value
+ operation_setting.CustomCallbackAddress = value
case "EpayId":
- setting.EpayId = value
+ operation_setting.EpayId = value
case "EpayKey":
- setting.EpayKey = value
+ operation_setting.EpayKey = value
case "Price":
- setting.Price, _ = strconv.ParseFloat(value, 64)
+ operation_setting.Price, _ = strconv.ParseFloat(value, 64)
case "USDExchangeRate":
- setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
+ operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
- setting.MinTopUp, _ = strconv.Atoi(value)
+ operation_setting.MinTopUp, _ = strconv.Atoi(value)
case "StripeApiSecret":
setting.StripeApiSecret = value
case "StripeWebhookSecret":
@@ -326,6 +331,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "StripeMinTopUp":
setting.StripeMinTopUp, _ = strconv.Atoi(value)
+ case "StripePromotionCodesEnabled":
+ setting.StripePromotionCodesEnabled = value == "true"
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
@@ -396,6 +403,12 @@ func updateOptionMap(key string, value string) (err error) {
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = ratio_setting.UpdateCacheRatioByJSONString(value)
+ case "ImageRatio":
+ err = ratio_setting.UpdateImageRatioByJSONString(value)
+ case "AudioRatio":
+ err = ratio_setting.UpdateAudioRatioByJSONString(value)
+ case "AudioCompletionRatio":
+ err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -413,7 +426,7 @@ func updateOptionMap(key string, value string) (err error) {
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
- err = setting.UpdatePayMethodsByJsonString(value)
+ err = operation_setting.UpdatePayMethodsByJsonString(value)
}
return err
}
diff --git a/model/passkey.go b/model/passkey.go
new file mode 100644
index 00000000..5b2a1547
--- /dev/null
+++ b/model/passkey.go
@@ -0,0 +1,209 @@
+package model
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "one-api/common"
+ "strings"
+ "time"
+
+ "github.com/go-webauthn/webauthn/protocol"
+ "github.com/go-webauthn/webauthn/webauthn"
+ "gorm.io/gorm"
+)
+
+var (
+ ErrPasskeyNotFound = errors.New("passkey credential not found")
+ ErrFriendlyPasskeyNotFound = errors.New("Passkey 验证失败,请重试或联系管理员")
+)
+
+type PasskeyCredential struct {
+ ID int `json:"id" gorm:"primaryKey"`
+ UserID int `json:"user_id" gorm:"uniqueIndex;not null"`
+ CredentialID string `json:"credential_id" gorm:"type:varchar(512);uniqueIndex;not null"` // base64 encoded
+ PublicKey string `json:"public_key" gorm:"type:text;not null"` // base64 encoded
+ AttestationType string `json:"attestation_type" gorm:"type:varchar(255)"`
+ AAGUID string `json:"aaguid" gorm:"type:varchar(512)"` // base64 encoded
+ SignCount uint32 `json:"sign_count" gorm:"default:0"`
+ CloneWarning bool `json:"clone_warning"`
+ UserPresent bool `json:"user_present"`
+ UserVerified bool `json:"user_verified"`
+ BackupEligible bool `json:"backup_eligible"`
+ BackupState bool `json:"backup_state"`
+ Transports string `json:"transports" gorm:"type:text"`
+ Attachment string `json:"attachment" gorm:"type:varchar(32)"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
+}
+
+func (p *PasskeyCredential) TransportList() []protocol.AuthenticatorTransport {
+ if p == nil || strings.TrimSpace(p.Transports) == "" {
+ return nil
+ }
+ var transports []string
+ if err := json.Unmarshal([]byte(p.Transports), &transports); err != nil {
+ return nil
+ }
+ result := make([]protocol.AuthenticatorTransport, 0, len(transports))
+ for _, transport := range transports {
+ result = append(result, protocol.AuthenticatorTransport(transport))
+ }
+ return result
+}
+
+func (p *PasskeyCredential) SetTransports(list []protocol.AuthenticatorTransport) {
+ if len(list) == 0 {
+ p.Transports = ""
+ return
+ }
+ stringList := make([]string, len(list))
+ for i, transport := range list {
+ stringList[i] = string(transport)
+ }
+ encoded, err := json.Marshal(stringList)
+ if err != nil {
+ return
+ }
+ p.Transports = string(encoded)
+}
+
+func (p *PasskeyCredential) ToWebAuthnCredential() webauthn.Credential {
+ flags := webauthn.CredentialFlags{
+ UserPresent: p.UserPresent,
+ UserVerified: p.UserVerified,
+ BackupEligible: p.BackupEligible,
+ BackupState: p.BackupState,
+ }
+
+ credID, _ := base64.StdEncoding.DecodeString(p.CredentialID)
+ pubKey, _ := base64.StdEncoding.DecodeString(p.PublicKey)
+ aaguid, _ := base64.StdEncoding.DecodeString(p.AAGUID)
+
+ return webauthn.Credential{
+ ID: credID,
+ PublicKey: pubKey,
+ AttestationType: p.AttestationType,
+ Transport: p.TransportList(),
+ Flags: flags,
+ Authenticator: webauthn.Authenticator{
+ AAGUID: aaguid,
+ SignCount: p.SignCount,
+ CloneWarning: p.CloneWarning,
+ Attachment: protocol.AuthenticatorAttachment(p.Attachment),
+ },
+ }
+}
+
+func NewPasskeyCredentialFromWebAuthn(userID int, credential *webauthn.Credential) *PasskeyCredential {
+ if credential == nil {
+ return nil
+ }
+ passkey := &PasskeyCredential{
+ UserID: userID,
+ CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
+ PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
+ AttestationType: credential.AttestationType,
+ AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
+ SignCount: credential.Authenticator.SignCount,
+ CloneWarning: credential.Authenticator.CloneWarning,
+ UserPresent: credential.Flags.UserPresent,
+ UserVerified: credential.Flags.UserVerified,
+ BackupEligible: credential.Flags.BackupEligible,
+ BackupState: credential.Flags.BackupState,
+ Attachment: string(credential.Authenticator.Attachment),
+ }
+ passkey.SetTransports(credential.Transport)
+ return passkey
+}
+
+func (p *PasskeyCredential) ApplyValidatedCredential(credential *webauthn.Credential) {
+ if credential == nil || p == nil {
+ return
+ }
+ p.CredentialID = base64.StdEncoding.EncodeToString(credential.ID)
+ p.PublicKey = base64.StdEncoding.EncodeToString(credential.PublicKey)
+ p.AttestationType = credential.AttestationType
+ p.AAGUID = base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID)
+ p.SignCount = credential.Authenticator.SignCount
+ p.CloneWarning = credential.Authenticator.CloneWarning
+ p.UserPresent = credential.Flags.UserPresent
+ p.UserVerified = credential.Flags.UserVerified
+ p.BackupEligible = credential.Flags.BackupEligible
+ p.BackupState = credential.Flags.BackupState
+ p.Attachment = string(credential.Authenticator.Attachment)
+ p.SetTransports(credential.Transport)
+}
+
+func GetPasskeyByUserID(userID int) (*PasskeyCredential, error) {
+ if userID == 0 {
+ common.SysLog("GetPasskeyByUserID: empty user ID")
+ return nil, ErrFriendlyPasskeyNotFound
+ }
+ var credential PasskeyCredential
+ if err := DB.Where("user_id = ?", userID).First(&credential).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ // 未找到记录是正常情况(用户未绑定),返回 ErrPasskeyNotFound 而不记录日志
+ return nil, ErrPasskeyNotFound
+ }
+ // 只有真正的数据库错误才记录日志
+ common.SysLog(fmt.Sprintf("GetPasskeyByUserID: database error for user %d: %v", userID, err))
+ return nil, ErrFriendlyPasskeyNotFound
+ }
+ return &credential, nil
+}
+
+func GetPasskeyByCredentialID(credentialID []byte) (*PasskeyCredential, error) {
+ if len(credentialID) == 0 {
+ common.SysLog("GetPasskeyByCredentialID: empty credential ID")
+ return nil, ErrFriendlyPasskeyNotFound
+ }
+
+ credIDStr := base64.StdEncoding.EncodeToString(credentialID)
+ var credential PasskeyCredential
+ if err := DB.Where("credential_id = ?", credIDStr).First(&credential).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: passkey not found for credential ID length %d", len(credentialID)))
+ return nil, ErrFriendlyPasskeyNotFound
+ }
+ common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: database error for credential ID: %v", err))
+ return nil, ErrFriendlyPasskeyNotFound
+ }
+
+ return &credential, nil
+}
+
+func UpsertPasskeyCredential(credential *PasskeyCredential) error {
+ if credential == nil {
+ common.SysLog("UpsertPasskeyCredential: nil credential provided")
+ return fmt.Errorf("Passkey 保存失败,请重试")
+ }
+ return DB.Transaction(func(tx *gorm.DB) error {
+ // 使用Unscoped()进行硬删除,避免唯一索引冲突
+ if err := tx.Unscoped().Where("user_id = ?", credential.UserID).Delete(&PasskeyCredential{}).Error; err != nil {
+ common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to delete existing credential for user %d: %v", credential.UserID, err))
+ return fmt.Errorf("Passkey 保存失败,请重试")
+ }
+ if err := tx.Create(credential).Error; err != nil {
+ common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to create credential for user %d: %v", credential.UserID, err))
+ return fmt.Errorf("Passkey 保存失败,请重试")
+ }
+ return nil
+ })
+}
+
+func DeletePasskeyByUserID(userID int) error {
+ if userID == 0 {
+ common.SysLog("DeletePasskeyByUserID: empty user ID")
+ return fmt.Errorf("删除失败,请重试")
+ }
+ // 使用Unscoped()进行硬删除,避免唯一索引冲突
+ if err := DB.Unscoped().Where("user_id = ?", userID).Delete(&PasskeyCredential{}).Error; err != nil {
+ common.SysLog(fmt.Sprintf("DeletePasskeyByUserID: failed to delete passkey for user %d: %v", userID, err))
+ return fmt.Errorf("删除失败,请重试")
+ }
+ return nil
+}
diff --git a/model/task.go b/model/task.go
index 4c64a529..8e2b6d0b 100644
--- a/model/task.go
+++ b/model/task.go
@@ -24,7 +24,7 @@ type Task struct {
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
CreatedAt int64 `json:"created_at" gorm:"index"`
UpdatedAt int64 `json:"updated_at"`
- TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id
+ TaskID string `json:"task_id" gorm:"type:varchar(191);index"` // 第三方id,不一定有/ song id\ Task id
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
UserId int `json:"user_id" gorm:"index"`
ChannelId int `json:"channel_id" gorm:"index"`
diff --git a/model/user.go b/model/user.go
index ea0584c5..d3e40fa3 100644
--- a/model/user.go
+++ b/model/user.go
@@ -18,7 +18,7 @@ import (
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
Id int `json:"id"`
- Username string `json:"username" gorm:"unique;index" validate:"max=12"`
+ Username string `json:"username" gorm:"unique;index" validate:"max=20"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
diff --git a/relay/audio_handler.go b/relay/audio_handler.go
index 711cc7a9..1357e381 100644
--- a/relay/audio_handler.go
+++ b/relay/audio_handler.go
@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index a50d5bdb..79a0f706 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -264,9 +264,9 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
}
resp, err := client.Do(req)
-
if err != nil {
- return nil, err
+ logger.LogError(c, "do request failed: "+err.Error())
+ return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
}
if resp == nil {
return nil, errors.New("resp is nil")
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index 1526a7f7..6202c9fc 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -52,6 +52,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ anthropicBeta := c.Request.Header.Get("anthropic-beta")
+ if anthropicBeta != "" {
+ req.Set("anthropic-beta", anthropicBeta)
+ }
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}
@@ -60,7 +64,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
+ // 检查是否为Nova模型
+ if isNovaModel(request.Model) {
+ novaReq := convertToNovaRequest(request)
+ c.Set("request_model", request.Model)
+ c.Set("converted_request", novaReq)
+ c.Set("is_nova_model", true)
+ return novaReq, nil
+ }
+ // 原有的Claude模型处理逻辑
var claudeReq *dto.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
@@ -69,6 +82,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
+ c.Set("is_nova_model", false)
return claudeReq, err
}
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
index 3f8800b1..45112d23 100644
--- a/relay/channel/aws/constants.go
+++ b/relay/channel/aws/constants.go
@@ -1,5 +1,7 @@
package aws
+import "strings"
+
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
@@ -14,6 +16,16 @@ var awsModelIDMap = map[string]string{
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
+ "claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
+ // Nova models
+ "nova-micro-v1:0": "amazon.nova-micro-v1:0",
+ "nova-lite-v1:0": "amazon.nova-lite-v1:0",
+ "nova-pro-v1:0": "amazon.nova-pro-v1:0",
+ "nova-premier-v1:0": "amazon.nova-premier-v1:0",
+ "nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
+ "nova-reel-v1:0": "amazon.nova-reel-v1:0",
+ "nova-reel-v1:1": "amazon.nova-reel-v1:1",
+ "nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -58,6 +70,48 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-opus-4-1-20250805-v1:0": {
"us": true,
},
+ "anthropic.claude-sonnet-4-5-20250929-v1:0": {
+ "us": true,
+ "ap": true,
+ "eu": true,
+ },
+ // Nova models - all support three major regions
+ "amazon.nova-micro-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
+ "amazon.nova-lite-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
+ "amazon.nova-pro-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
+ "amazon.nova-premier-v1:0": {
+ "us": true,
+ },
+ "amazon.nova-canvas-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
+ "amazon.nova-reel-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
+ "amazon.nova-reel-v1:1": {
+ "us": true,
+ },
+ "amazon.nova-sonic-v1:0": {
+ "us": true,
+ "eu": true,
+ "apac": true,
+ },
}
var awsRegionCrossModelPrefixMap = map[string]string{
@@ -67,3 +121,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
}
var ChannelName = "aws"
+
+// 判断是否为Nova模型
+func isNovaModel(modelId string) bool {
+ return strings.HasPrefix(modelId, "nova-")
+}
diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go
index 0188c30a..9c9fe946 100644
--- a/relay/channel/aws/dto.go
+++ b/relay/channel/aws/dto.go
@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
Thinking: req.Thinking,
}
}
+
+// NovaMessage Nova模型使用messages-v1格式
+type NovaMessage struct {
+ Role string `json:"role"`
+ Content []NovaContent `json:"content"`
+}
+
+type NovaContent struct {
+ Text string `json:"text"`
+}
+
+type NovaRequest struct {
+ SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
+ Messages []NovaMessage `json:"messages"` // 对话消息列表
+ InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
+}
+
+type NovaInferenceConfig struct {
+ MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
+ Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
+ TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
+ TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
+ StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
+}
+
+// 转换OpenAI请求为Nova格式
+func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
+ novaMessages := make([]NovaMessage, len(req.Messages))
+ for i, msg := range req.Messages {
+ novaMessages[i] = NovaMessage{
+ Role: msg.Role,
+ Content: []NovaContent{{Text: msg.StringContent()}},
+ }
+ }
+
+ novaReq := &NovaRequest{
+ SchemaVersion: "messages-v1",
+ Messages: novaMessages,
+ }
+
+ // 设置推理配置
+ if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
+ novaReq.InferenceConfig = &NovaInferenceConfig{}
+ if req.MaxTokens != 0 {
+ novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
+ }
+ if req.Temperature != nil && *req.Temperature != 0 {
+ novaReq.InferenceConfig.Temperature = *req.Temperature
+ }
+ if req.TopP != 0 {
+ novaReq.InferenceConfig.TopP = req.TopP
+ }
+ if req.TopK != 0 {
+ novaReq.InferenceConfig.TopK = req.TopK
+ }
+ if req.Stop != nil {
+ if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
+ novaReq.InferenceConfig.StopSequences = stopSequences
+ }
+ }
+ }
+
+ return novaReq
+}
+
+// parseStopSequences 解析停止序列,支持字符串或字符串数组
+func parseStopSequences(stop any) []string {
+ if stop == nil {
+ return nil
+ }
+
+ switch v := stop.(type) {
+ case string:
+ if v != "" {
+ return []string{v}
+ }
+ case []string:
+ return v
+ case []interface{}:
+ var sequences []string
+ for _, item := range v {
+ if str, ok := item.(string); ok && str != "" {
+ sequences = append(sequences, str)
+ }
+ }
+ return sequences
+ }
+ return nil
+}
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
index 26e234fa..eef26855 100644
--- a/relay/channel/aws/relay-aws.go
+++ b/relay/channel/aws/relay-aws.go
@@ -1,6 +1,7 @@
package aws
import (
+ "encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
}
awsModelId := awsModelID(c.GetString("request_model"))
+ // 检查是否为Nova模型
+ isNova, _ := c.Get("is_nova_model")
+ if isNova == true {
+ // Nova模型也支持跨区域
+ awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+ canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+ if canCrossRegion {
+ awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+ }
+ return handleNovaRequest(c, awsCli, info, awsModelId)
+ }
+ // 原有的Claude处理逻辑
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
@@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage
}
+
+// Nova模型处理函数
+func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
+ novaReq_, ok := c.Get("converted_request")
+ if !ok {
+ return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
+ }
+ novaReq := novaReq_.(*NovaRequest)
+
+ // 使用InvokeModel API,但使用Nova格式的请求体
+ awsReq := &bedrockruntime.InvokeModelInput{
+ ModelId: aws.String(awsModelId),
+ Accept: aws.String("application/json"),
+ ContentType: aws.String("application/json"),
+ }
+
+ reqBody, err := json.Marshal(novaReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
+ }
+ awsReq.Body = reqBody
+
+ awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+ }
+
+ // 解析Nova响应
+ var novaResp struct {
+ Output struct {
+ Message struct {
+ Content []struct {
+ Text string `json:"text"`
+ } `json:"content"`
+ } `json:"message"`
+ } `json:"output"`
+ Usage struct {
+ InputTokens int `json:"inputTokens"`
+ OutputTokens int `json:"outputTokens"`
+ TotalTokens int `json:"totalTokens"`
+ } `json:"usage"`
+ }
+
+ if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
+ return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
+ }
+
+ // 构造OpenAI格式响应
+ response := dto.OpenAITextResponse{
+ Id: helper.GetResponseID(c),
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ Choices: []dto.OpenAITextResponseChoice{{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: novaResp.Output.Message.Content[0].Text,
+ },
+ FinishReason: "stop",
+ }},
+ Usage: dto.Usage{
+ PromptTokens: novaResp.Usage.InputTokens,
+ CompletionTokens: novaResp.Usage.OutputTokens,
+ TotalTokens: novaResp.Usage.TotalTokens,
+ },
+ }
+
+ c.JSON(http.StatusOK, response)
+ return nil, &response.Usage
+}
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 959327e1..362f09e7 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -52,11 +52,16 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ baseURL := ""
if a.RequestMode == RequestModeMessage {
- return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
+ baseURL = fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
} else {
- return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
+ baseURL = fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl)
}
+ if info.IsClaudeBetaQuery {
+ baseURL = baseURL + "?beta=true"
+ }
+ return baseURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -67,6 +72,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
anthropicVersion = "2023-06-01"
}
req.Set("anthropic-version", anthropicVersion)
+ anthropicBeta := c.Request.Header.Get("anthropic-beta")
+ if anthropicBeta != "" {
+ req.Set("anthropic-beta", anthropicBeta)
+ }
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}
diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go
index a23543d2..d0b36fe4 100644
--- a/relay/channel/claude/constants.go
+++ b/relay/channel/claude/constants.go
@@ -19,6 +19,8 @@ var ModelList = []string{
"claude-opus-4-20250514-thinking",
"claude-opus-4-1-20250805",
"claude-opus-4-1-20250805-thinking",
+ "claude-sonnet-4-5-20250929",
+ "claude-sonnet-4-5-20250929-thinking",
}
var ChannelName = "claude"
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
index 17d732ab..962f8794 100644
--- a/relay/channel/deepseek/adaptor.go
+++ b/relay/channel/deepseek/adaptor.go
@@ -3,17 +3,17 @@ package deepseek
import (
"errors"
"fmt"
+ "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
-
- "github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
+ adaptor := claude.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
@@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.ChannelBaseUrl
- if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
- fimBaseUrl += "/beta"
- }
- switch info.RelayMode {
- case constant.RelayModeCompletions:
- return fmt.Sprintf("%s/completions", fimBaseUrl), nil
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
+ fimBaseUrl += "/beta"
+ }
+ switch info.RelayMode {
+ case constant.RelayModeCompletions:
+ return fmt.Sprintf("%s/completions", fimBaseUrl), nil
+ default:
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ }
}
}
@@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ if info.IsStream {
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ } else {
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+ }
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
}
- return
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index 4968f78f..57542aa5 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
- if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
- strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+ if strings.Contains(info.RequestURLPath, ":embedContent") ||
+ strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
return NativeGeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream {
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index 564b8690..974a22f5 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
- imageOutputCounts := 0
- for _, candidate := range geminiResponse.Candidates {
- for _, part := range candidate.Content.Parts {
- if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
- imageOutputCounts++
- }
- }
- }
- if imageOutputCounts != 0 {
- usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
- usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
- c.Set("gemini_image_tokens", imageOutputCounts*1290)
- }
- }
-
- // if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
- // for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
- // if detail.Modality == "IMAGE" {
- // usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
- // usage.TotalTokens = usage.TotalTokens - detail.TokenCount
- // c.Set("gemini_image_tokens", detail.TokenCount)
- // }
- // }
- // }
-
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
@@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
-
- if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
- for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
- if detail.Modality == "IMAGE" {
- usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
- usage.TotalTokens = usage.TotalTokens - detail.TokenCount
- c.Set("gemini_image_tokens", detail.TokenCount)
- }
- }
- }
}
// 直接发送 GeminiChatResponse 响应
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index eb4afbae..c8e9c757 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -23,6 +23,7 @@ import (
"github.com/gin-gonic/gin"
)
+// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
var geminiSupportedMimeTypes = map[string]bool{
"application/pdf": true,
"audio/mpeg": true,
@@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
+ "image/webp": true,
"text/plain": true,
"video/mov": true,
"video/mpeg": true,
@@ -243,6 +245,7 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
+ urlContext := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
@@ -252,6 +255,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
codeExecution = true
continue
}
+ if tool.Function.Name == "urlContext" {
+ urlContext = true
+ continue
+ }
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
@@ -279,6 +286,11 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
GoogleSearch: make(map[string]string),
})
}
+ if urlContext {
+ geminiTools = append(geminiTools, dto.GeminiChatTool{
+ URLContext: make(map[string]string),
+ })
+ }
if len(functions) > 0 {
geminiTools = append(geminiTools, dto.GeminiChatTool{
FunctionDeclarations: functions,
diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go
index a383728f..dbfe314d 100644
--- a/relay/channel/jina/adaptor.go
+++ b/relay/channel/jina/adaptor.go
@@ -76,6 +76,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ request.EncodingFormat = ""
return request, nil
}
diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go
index e290c239..f24976bb 100644
--- a/relay/channel/moonshot/adaptor.go
+++ b/relay/channel/moonshot/adaptor.go
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
+ adaptor := claude.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index d6b5b697..bafe73b9 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -10,6 +10,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
+ "strings"
"github.com/gin-gonic/gin"
)
@@ -17,10 +18,7 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{}
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
- return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
+ // map to ollama chat request (Claude -> OpenAI -> Ollama chat)
+ return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
}
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayFormat == types.RelayFormatClaude {
- return info.ChannelBaseUrl + "/v1/chat/completions", nil
- }
- switch info.RelayMode {
- case relayconstant.RelayModeEmbeddings:
- return info.ChannelBaseUrl + "/api/embed", nil
- default:
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
- }
+ if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
+ if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
+ return info.ChannelBaseUrl + "/api/chat", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
+ if request == nil { return nil, errors.New("request is nil") }
+ // decide generate or chat
+ if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+ return openAIToGenerate(c, request)
}
- return requestOpenAI2Ollama(c, request)
+ return openAIChatToOllamaChat(c, request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return requestOpenAI2Embeddings(request), nil
}
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
- usage, err = ollamaEmbeddingHandler(c, info, resp)
+ return ollamaEmbeddingHandler(c, info, resp)
default:
if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
+ return ollamaStreamHandler(c, info, resp)
}
+ return ollamaChatHandler(c, info, resp)
}
- return
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index 317c2a4a..45e49ab4 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -2,48 +2,69 @@ package ollama
import (
"encoding/json"
- "one-api/dto"
)
-type OllamaRequest struct {
- Model string `json:"model,omitempty"`
- Messages []dto.Message `json:"messages,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Topp float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- Tools []dto.ToolCallRequest `json:"tools,omitempty"`
- ResponseFormat any `json:"response_format,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Think json.RawMessage `json:"think,omitempty"`
+type OllamaChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content,omitempty"`
+ Images []string `json:"images,omitempty"`
+ ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
+ ToolName string `json:"tool_name,omitempty"`
+ Thinking json.RawMessage `json:"thinking,omitempty"`
}
-type Options struct {
- Seed int `json:"seed,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopK int `json:"top_k,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- NumPredict int `json:"num_predict,omitempty"`
- NumCtx int `json:"num_ctx,omitempty"`
+type OllamaToolFunction struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Parameters interface{} `json:"parameters,omitempty"`
+}
+
+type OllamaTool struct {
+ Type string `json:"type"`
+ Function OllamaToolFunction `json:"function"`
+}
+
+type OllamaToolCall struct {
+ Function struct {
+ Name string `json:"name"`
+ Arguments interface{} `json:"arguments"`
+ } `json:"function"`
+}
+
+type OllamaChatRequest struct {
+ Model string `json:"model"`
+ Messages []OllamaChatMessage `json:"messages"`
+ Tools interface{} `json:"tools,omitempty"`
+ Format interface{} `json:"format,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Options map[string]any `json:"options,omitempty"`
+ KeepAlive interface{} `json:"keep_alive,omitempty"`
+ Think json.RawMessage `json:"think,omitempty"`
+}
+
+type OllamaGenerateRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt,omitempty"`
+ Suffix string `json:"suffix,omitempty"`
+ Images []string `json:"images,omitempty"`
+ Format interface{} `json:"format,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Options map[string]any `json:"options,omitempty"`
+ KeepAlive interface{} `json:"keep_alive,omitempty"`
+ Think json.RawMessage `json:"think,omitempty"`
}
type OllamaEmbeddingRequest struct {
- Model string `json:"model,omitempty"`
- Input []string `json:"input"`
- Options *Options `json:"options,omitempty"`
+ Model string `json:"model"`
+ Input interface{} `json:"input"`
+ Options map[string]any `json:"options,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
}
type OllamaEmbeddingResponse struct {
- Error string `json:"error,omitempty"`
- Model string `json:"model"`
- Embedding [][]float64 `json:"embeddings,omitempty"`
+ Error string `json:"error,omitempty"`
+ Model string `json:"model"`
+ Embeddings [][]float64 `json:"embeddings"`
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
}
+
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 27c67b4e..3b67f952 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -1,6 +1,7 @@
package ollama
import (
+ "encoding/json"
"fmt"
"io"
"net/http"
@@ -14,121 +15,176 @@ import (
"github.com/gin-gonic/gin"
)
-func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
- messages := make([]dto.Message, 0, len(request.Messages))
- for _, message := range request.Messages {
- if !message.IsStringContent() {
- mediaMessages := message.ParseContent()
- for j, mediaMessage := range mediaMessages {
- if mediaMessage.Type == dto.ContentTypeImageURL {
- imageUrl := mediaMessage.GetImageMedia()
- // check if not base64
- if strings.HasPrefix(imageUrl.Url, "http") {
- fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
- if err != nil {
- return nil, err
+func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
+ chatReq := &OllamaChatRequest{
+ Model: r.Model,
+ Stream: r.Stream,
+ Options: map[string]any{},
+ Think: r.Think,
+ }
+ if r.ResponseFormat != nil {
+ if r.ResponseFormat.Type == "json" {
+ chatReq.Format = "json"
+ } else if r.ResponseFormat.Type == "json_schema" {
+ if len(r.ResponseFormat.JsonSchema) > 0 {
+ var schema any
+ _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+ chatReq.Format = schema
+ }
+ }
+ }
+
+ // options mapping
+ if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
+ if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
+ if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
+ if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
+ if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
+ if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
+ if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
+
+ if r.Stop != nil {
+ switch v := r.Stop.(type) {
+ case string:
+ chatReq.Options["stop"] = []string{v}
+ case []string:
+ chatReq.Options["stop"] = v
+ case []any:
+ arr := make([]string,0,len(v))
+ for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
+ if len(arr)>0 { chatReq.Options["stop"] = arr }
+ }
+ }
+
+ if len(r.Tools) > 0 {
+ tools := make([]OllamaTool,0,len(r.Tools))
+ for _, t := range r.Tools {
+ tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
+ }
+ chatReq.Tools = tools
+ }
+
+ chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
+ for _, m := range r.Messages {
+ var textBuilder strings.Builder
+ var images []string
+ if m.IsStringContent() {
+ textBuilder.WriteString(m.StringContent())
+ } else {
+ parts := m.ParseContent()
+ for _, part := range parts {
+ if part.Type == dto.ContentTypeImageURL {
+ img := part.GetImageMedia()
+ if img != nil && img.Url != "" {
+ var base64Data string
+ if strings.HasPrefix(img.Url, "http") {
+ fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
+ if err != nil { return nil, err }
+ base64Data = fileData.Base64Data
+ } else if strings.HasPrefix(img.Url, "data:") {
+ if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
+ } else {
+ base64Data = img.Url
}
- imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+ if base64Data != "" { images = append(images, base64Data) }
}
- mediaMessage.ImageUrl = imageUrl
- mediaMessages[j] = mediaMessage
+ } else if part.Type == dto.ContentTypeText {
+ textBuilder.WriteString(part.Text)
}
}
- message.SetMediaContent(mediaMessages)
}
- messages = append(messages, dto.Message{
- Role: message.Role,
- Content: message.Content,
- ToolCalls: message.ToolCalls,
- ToolCallId: message.ToolCallId,
- })
+ cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
+ if len(images)>0 { cm.Images = images }
+ if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
+ if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
+ parsed := m.ParseToolCalls()
+ if len(parsed) > 0 {
+ calls := make([]OllamaToolCall,0,len(parsed))
+ for _, tc := range parsed {
+ var args interface{}
+ if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
+ if args==nil { args = map[string]any{} }
+ oc := OllamaToolCall{}
+ oc.Function.Name = tc.Function.Name
+ oc.Function.Arguments = args
+ calls = append(calls, oc)
+ }
+ cm.ToolCalls = calls
+ }
+ }
+ chatReq.Messages = append(chatReq.Messages, cm)
}
- str, ok := request.Stop.(string)
- var Stop []string
- if ok {
- Stop = []string{str}
- } else {
- Stop, _ = request.Stop.([]string)
- }
- ollamaRequest := &OllamaRequest{
- Model: request.Model,
- Messages: messages,
- Stream: request.Stream,
- Temperature: request.Temperature,
- Seed: request.Seed,
- Topp: request.TopP,
- TopK: request.TopK,
- Stop: Stop,
- Tools: request.Tools,
- MaxTokens: request.GetMaxTokens(),
- ResponseFormat: request.ResponseFormat,
- FrequencyPenalty: request.FrequencyPenalty,
- PresencePenalty: request.PresencePenalty,
- Prompt: request.Prompt,
- StreamOptions: request.StreamOptions,
- Suffix: request.Suffix,
- }
- ollamaRequest.Think = request.Think
- return ollamaRequest, nil
+ return chatReq, nil
}
-func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
- return &OllamaEmbeddingRequest{
- Model: request.Model,
- Input: request.ParseInput(),
- Options: &Options{
- Seed: int(request.Seed),
- Temperature: request.Temperature,
- TopP: request.TopP,
- FrequencyPenalty: request.FrequencyPenalty,
- PresencePenalty: request.PresencePenalty,
- },
+// openAIToGenerate converts OpenAI completions request to Ollama generate
+func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
+ gen := &OllamaGenerateRequest{
+ Model: r.Model,
+ Stream: r.Stream,
+ Options: map[string]any{},
+ Think: r.Think,
}
+ // Prompt may be in r.Prompt (string or []any)
+ if r.Prompt != nil {
+ switch v := r.Prompt.(type) {
+ case string:
+ gen.Prompt = v
+ case []any:
+ var sb strings.Builder
+ for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
+ gen.Prompt = sb.String()
+ default:
+ gen.Prompt = fmt.Sprintf("%v", r.Prompt)
+ }
+ }
+ if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
+ if r.ResponseFormat != nil {
+ if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
+ }
+ if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
+ if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
+ if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
+ if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
+ if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
+ if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
+ if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
+ if r.Stop != nil {
+ switch v := r.Stop.(type) {
+ case string: gen.Options["stop"] = []string{v}
+ case []string: gen.Options["stop"] = v
+ case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
+ }
+ }
+ return gen, nil
+}
+
+func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+ opts := map[string]any{}
+ if r.Temperature != nil { opts["temperature"] = r.Temperature }
+ if r.TopP != 0 { opts["top_p"] = r.TopP }
+ if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
+ if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
+ if r.Seed != 0 { opts["seed"] = int(r.Seed) }
+ if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
+ input := r.ParseInput()
+ if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
+ return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
}
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var ollamaEmbeddingResponse OllamaEmbeddingResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
+ var oResp OllamaEmbeddingResponse
+ body, err := io.ReadAll(resp.Body)
+ if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
service.CloseResponseBodyGracefully(resp)
- err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if ollamaEmbeddingResponse.Error != "" {
- return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
- data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
- data = append(data, dto.OpenAIEmbeddingResponseItem{
- Embedding: flattenedEmbeddings,
- Object: "embedding",
- })
- usage := &dto.Usage{
- TotalTokens: info.PromptTokens,
- CompletionTokens: 0,
- PromptTokens: info.PromptTokens,
- }
- embeddingResponse := &dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: data,
- Model: info.UpstreamModelName,
- Usage: *usage,
- }
- doResponseBody, err := common.Marshal(embeddingResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- service.IOCopyBytesGracefully(c, resp, doResponseBody)
+ if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+ if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+ data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
+ for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
+ usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
+ embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
+ out, _ := common.Marshal(embResp)
+ service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}
-func flattenEmbeddings(embeddings [][]float64) []float64 {
- flattened := []float64{}
- for _, row := range embeddings {
- flattened = append(flattened, row...)
- }
- return flattened
-}
diff --git a/relay/channel/ollama/stream.go b/relay/channel/ollama/stream.go
new file mode 100644
index 00000000..964f11d9
--- /dev/null
+++ b/relay/channel/ollama/stream.go
@@ -0,0 +1,210 @@
+package ollama
+
+import (
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/logger"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ollamaChatStreamChunk struct {
+ Model string `json:"model"`
+ CreatedAt string `json:"created_at"`
+ // chat
+ Message *struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Thinking json.RawMessage `json:"thinking"`
+ ToolCalls []struct {
+ Function struct {
+ Name string `json:"name"`
+ Arguments interface{} `json:"arguments"`
+ } `json:"function"`
+ } `json:"tool_calls"`
+ } `json:"message"`
+ // generate
+ Response string `json:"response"`
+ Done bool `json:"done"`
+ DoneReason string `json:"done_reason"`
+ TotalDuration int64 `json:"total_duration"`
+ LoadDuration int64 `json:"load_duration"`
+ PromptEvalCount int `json:"prompt_eval_count"`
+ EvalCount int `json:"eval_count"`
+ PromptEvalDuration int64 `json:"prompt_eval_duration"`
+ EvalDuration int64 `json:"eval_duration"`
+}
+
+func toUnix(ts string) int64 {
+ if ts == "" { return time.Now().Unix() }
+ // try time.RFC3339 or with nanoseconds
+ t, err := time.Parse(time.RFC3339Nano, ts)
+ if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
+ return t.Unix()
+}
+
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
+ defer service.CloseResponseBodyGracefully(resp)
+
+ helper.SetEventStreamHeaders(c)
+ scanner := bufio.NewScanner(resp.Body)
+ usage := &dto.Usage{}
+ var model = info.UpstreamModelName
+ var responseId = common.GetUUID()
+ var created = time.Now().Unix()
+ var toolCallIndex int
+ start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+ if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+ for scanner.Scan() {
+ line := scanner.Text()
+ line = strings.TrimSpace(line)
+ if line == "" { continue }
+ var chunk ollamaChatStreamChunk
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+ logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+ return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ if chunk.Model != "" { model = chunk.Model }
+ created = toUnix(chunk.CreatedAt)
+
+ if !chunk.Done {
+ // delta content
+ var content string
+ if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
+ delta := dto.ChatCompletionsStreamResponse{
+ Id: responseId,
+ Object: "chat.completion.chunk",
+ Created: created,
+ Model: model,
+ Choices: []dto.ChatCompletionsStreamResponseChoice{ {
+ Index: 0,
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
+ } },
+ }
+ if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+ if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+ raw := strings.TrimSpace(string(chunk.Message.Thinking))
+ if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+ }
+ // tool calls
+ if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+ delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
+ for _, tc := range chunk.Message.ToolCalls {
+ // arguments -> string
+ argBytes, _ := json.Marshal(tc.Function.Arguments)
+ toolId := fmt.Sprintf("call_%d", toolCallIndex)
+ tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+ tr.SetIndex(toolCallIndex)
+ toolCallIndex++
+ delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+ }
+ }
+ if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+ continue
+ }
+ // done frame
+ // finalize once and break loop
+ usage.PromptTokens = chunk.PromptEvalCount
+ usage.CompletionTokens = chunk.EvalCount
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ finishReason := chunk.DoneReason
+ if finishReason == "" { finishReason = "stop" }
+ // emit stop delta
+ if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
+ if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+ }
+ // emit usage frame
+ if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
+ if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+ }
+ // send [DONE]
+ helper.Done(c)
+ break
+ }
+ if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+ return usage, nil
+}
+
+// non-stream handler for chat/generate
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
+ service.CloseResponseBodyGracefully(resp)
+ raw := string(body)
+ if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+ lines := strings.Split(raw, "\n")
+ var (
+ aggContent strings.Builder
+ reasoningBuilder strings.Builder
+ lastChunk ollamaChatStreamChunk
+ parsedAny bool
+ )
+ for _, ln := range lines {
+ ln = strings.TrimSpace(ln)
+ if ln == "" { continue }
+ var ck ollamaChatStreamChunk
+ if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+ if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+ continue
+ }
+ parsedAny = true
+ lastChunk = ck
+ if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+ raw := strings.TrimSpace(string(ck.Message.Thinking))
+ if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
+ }
+ if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+ }
+
+ if !parsedAny {
+ var single ollamaChatStreamChunk
+ if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+ lastChunk = single
+ if single.Message != nil {
+ if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+ aggContent.WriteString(single.Message.Content)
+ } else { aggContent.WriteString(single.Response) }
+ }
+
+ model := lastChunk.Model
+ if model == "" { model = info.UpstreamModelName }
+ created := toUnix(lastChunk.CreatedAt)
+ usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+ content := aggContent.String()
+ finishReason := lastChunk.DoneReason
+ if finishReason == "" { finishReason = "stop" }
+
+ msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+ if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
+ full := dto.OpenAITextResponse{
+ Id: common.GetUUID(),
+ Model: model,
+ Object: "chat.completion",
+ Created: created,
+ Choices: []dto.OpenAITextResponseChoice{ {
+ Index: 0,
+ Message: msg,
+ FinishReason: finishReason,
+ } },
+ Usage: *usage,
+ }
+ out, _ := common.Marshal(full)
+ service.IOCopyBytesGracefully(c, resp, out)
+ return usage, nil
+}
+
+func contentPtr(s string) *string { if s=="" { return nil }; return &s }
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 4b13a7df..a88b6850 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -12,6 +12,7 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/logger"
+ "one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if common.DebugEnabled {
println("upstream response body:", string(responseBody))
}
+ // Unmarshal to simpleResponse
+ if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
+ // 尝试解析为 openrouter enterprise
+ var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
+ err = common.Unmarshal(responseBody, &enterpriseResponse)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ if enterpriseResponse.Success {
+ responseBody = enterpriseResponse.Data
+ } else {
+ logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
+ return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ }
+
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
+
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index e188889e..7b148f32 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
+ if responsesResponse.HasImageGenerationCall() {
+ c.Set("image_generation_call", true)
+ c.Set("image_generation_call_quality", responsesResponse.GetQuality())
+ c.Set("image_generation_call_size", responsesResponse.GetSize())
+ }
+
// 写入新的 response body
service.IOCopyBytesGracefully(c, resp, responseBody)
@@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
- if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
- if streamResponse.Response.Usage.InputTokens != 0 {
- usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ if streamResponse.Response != nil {
+ if streamResponse.Response.Usage != nil {
+ if streamResponse.Response.Usage.InputTokens != 0 {
+ usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ }
+ if streamResponse.Response.Usage.OutputTokens != 0 {
+ usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+ }
+ if streamResponse.Response.Usage.TotalTokens != 0 {
+ usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+ }
+ if streamResponse.Response.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+ }
}
- if streamResponse.Response.Usage.OutputTokens != 0 {
- usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
- }
- if streamResponse.Response.Usage.TotalTokens != 0 {
- usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
- }
- if streamResponse.Response.Usage.InputTokensDetails != nil {
- usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+ if streamResponse.Response.HasImageGenerationCall() {
+ c.Set("image_generation_call", true)
+ c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
+ c.Set("image_generation_call_size", streamResponse.Response.GetSize())
}
}
case "response.output_text.delta":
@@ -102,7 +115,11 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
if streamResponse.Item != nil {
switch streamResponse.Item.Type {
case dto.BuildInCallWebSearchCall:
- info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
+ if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
+ if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
+ webSearchTool.CallCount++
+ }
+ }
}
}
}
diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go
index 607f495b..a3249985 100644
--- a/relay/channel/openrouter/dto.go
+++ b/relay/channel/openrouter/dto.go
@@ -1,5 +1,7 @@
package openrouter
+import "encoding/json"
+
type RequestReasoning struct {
// One of the following (not both):
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
@@ -7,3 +9,8 @@ type RequestReasoning struct {
// Optional: Default is false. All models support this.
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
}
+
+type OpenRouterEnterpriseResponse struct {
+ Data json.RawMessage `json:"data"`
+ Success bool `json:"success"`
+}
diff --git a/relay/channel/submodel/adaptor.go b/relay/channel/submodel/adaptor.go
new file mode 100644
index 00000000..db58fe64
--- /dev/null
+++ b/relay/channel/submodel/adaptor.go
@@ -0,0 +1,86 @@
+package submodel
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ return nil, errors.New("submodel channel: endpoint not supported")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/submodel/constants.go b/relay/channel/submodel/constants.go
new file mode 100644
index 00000000..f5e1feb8
--- /dev/null
+++ b/relay/channel/submodel/constants.go
@@ -0,0 +1,16 @@
+package submodel
+
+var ModelList = []string{
+ "NousResearch/Hermes-4-405B-FP8",
+ "Qwen/Qwen3-235B-A22B-Thinking-2507",
+ "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8",
+ "Qwen/Qwen3-235B-A22B-Instruct-2507",
+ "zai-org/GLM-4.5-FP8",
+ "openai/gpt-oss-120b",
+ "deepseek-ai/DeepSeek-R1-0528",
+ "deepseek-ai/DeepSeek-R1",
+ "deepseek-ai/DeepSeek-V3-0324",
+ "deepseek-ai/DeepSeek-V3.1",
+}
+
+const ChannelName = "submodel"
\ No newline at end of file
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
index 955e592a..a2545a27 100644
--- a/relay/channel/task/jimeng/adaptor.go
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
- "one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
@@ -37,6 +36,7 @@ type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Seed int64 `json:"seed"`
AspectRatio string `json:"aspect_ratio"`
+ Frames int `json:"frames,omitempty"`
}
type responsePayload struct {
@@ -89,26 +89,14 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
- action := constant.TaskActionGenerate
- info.Action = action
-
- req := relaycommon.TaskSubmitReq{}
- if err := common.UnmarshalBodyReusable(c, &req); err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
- return
- }
- if strings.TrimSpace(req.Prompt) == "" {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
- return
- }
-
- // Store into context for later usage
- c.Set("task_request", req)
- return nil
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if isNewAPIRelay(info.ApiKey) {
+ return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
+ }
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
@@ -116,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
- return a.signRequest(req, a.accessKey, a.secretKey)
+ if isNewAPIRelay(info.ApiKey) {
+ req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+ } else {
+ return a.signRequest(req, a.accessKey, a.secretKey)
+ }
+ return nil
}
// BuildRequestBody converts request into Jimeng specific format.
@@ -176,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
+ if isNewAPIRelay(key) {
+ uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
+ }
payload := map[string]string{
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
@@ -193,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
- keyParts := strings.Split(key, "|")
- if len(keyParts) != 2 {
- return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- secretKey := strings.TrimSpace(keyParts[1])
+ if isNewAPIRelay(key) {
+ req.Header.Set("Authorization", "Bearer "+key)
+ } else {
+ keyParts := strings.Split(key, "|")
+ if len(keyParts) != 2 {
+ return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
- if err := a.signRequest(req, accessKey, secretKey); err != nil {
- return nil, errors.Wrap(err, "sign request failed")
+ if err := a.signRequest(req, accessKey, secretKey); err != nil {
+ return nil, errors.Wrap(err, "sign request failed")
+ }
}
-
return service.GetHttpClient().Do(req)
}
@@ -327,18 +326,23 @@ func hmacSHA256(key []byte, data []byte) []byte {
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
- ReqKey: "jimeng_vgfm_i2v_l20",
- Prompt: req.Prompt,
- AspectRatio: "16:9", // Default aspect ratio
- Seed: -1, // Default to random
+ ReqKey: req.Model,
+ Prompt: req.Prompt,
+ }
+
+ switch req.Duration {
+ case 10:
+ r.Frames = 241 // 24*10+1 = 241
+ default:
+ r.Frames = 121 // 24*5+1 = 121
}
// Handle one-of image_urls or binary_data_base64
- if req.Image != "" {
- if strings.HasPrefix(req.Image, "http") {
- r.ImageUrls = []string{req.Image}
+ if req.HasImage() {
+ if strings.HasPrefix(req.Images[0], "http") {
+ r.ImageUrls = req.Images
} else {
- r.BinaryDataBase64 = []string{req.Image}
+ r.BinaryDataBase64 = req.Images
}
}
metadata := req.Metadata
@@ -350,6 +354,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
+
+ // 即梦视频3.0 ReqKey转换
+ // https://www.volcengine.com/docs/85621/1792707
+ if strings.Contains(r.ReqKey, "jimeng_v30") {
+ if len(r.ImageUrls) > 1 {
+ // 多张图片:首尾帧生成
+ r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
+ } else if len(r.ImageUrls) == 1 {
+ // 单张图片:图生视频
+ r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
+ } else {
+ // 无图片:文生视频
+ r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
+ }
+ }
+
return &r, nil
}
@@ -378,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
taskResult.Url = resTask.Data.VideoUrl
return &taskResult, nil
}
+
+func isNewAPIRelay(apiKey string) bool {
+ return strings.HasPrefix(apiKey, "sk-")
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
index 3d6da253..fec3396a 100644
--- a/relay/channel/task/kling/adaptor.go
+++ b/relay/channel/task/kling/adaptor.go
@@ -16,7 +16,6 @@ import (
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
- "one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
@@ -28,16 +27,6 @@ import (
// Request / Response structures
// ============================
-type SubmitReq struct {
- Prompt string `json:"prompt"`
- Model string `json:"model,omitempty"`
- Mode string `json:"mode,omitempty"`
- Image string `json:"image,omitempty"`
- Size string `json:"size,omitempty"`
- Duration int `json:"duration,omitempty"`
- Metadata map[string]interface{} `json:"metadata,omitempty"`
-}
-
type TrajectoryPoint struct {
X int `json:"x"`
Y int `json:"y"`
@@ -121,28 +110,18 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // Accept only POST /v1/video/generations as "generate" action.
- action := constant.TaskActionGenerate
- info.Action = action
-
- var req SubmitReq
- if err := common.UnmarshalBodyReusable(c, &req); err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
- return
- }
- if strings.TrimSpace(req.Prompt) == "" {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
- return
- }
-
- // Store into context for later usage
- c.Set("task_request", req)
- return nil
+ // Use the standard validation method for TaskSubmitReq
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+
+ if isNewAPIRelay(info.ApiKey) {
+ return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
+ }
+
return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
@@ -166,7 +145,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if !exists {
return nil, fmt.Errorf("request not found in context")
}
- req := v.(SubmitReq)
+ req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
@@ -225,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
+ if isNewAPIRelay(key) {
+ url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
+ }
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -255,7 +237,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
@@ -330,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
//}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
-
+ if isNewAPIRelay(apiKey) {
+ return apiKey, nil // new api relay
+ }
keyParts := strings.Split(apiKey, "|")
+ if len(keyParts) != 2 {
+ return "", errors.New("invalid api_key, required format is accessKey|secretKey")
+ }
accessKey := strings.TrimSpace(keyParts[0])
if len(keyParts) == 1 {
return accessKey, nil
@@ -378,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
return taskInfo, nil
}
+
+func isNewAPIRelay(apiKey string) bool {
+ return strings.HasPrefix(apiKey, "sk-")
+}
diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go
new file mode 100644
index 00000000..4a236b2f
--- /dev/null
+++ b/relay/channel/task/vertex/adaptor.go
@@ -0,0 +1,355 @@
+package vertex
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/model"
+ "regexp"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ vertexcore "one-api/relay/channel/vertex"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type requestPayload struct {
+ Instances []map[string]any `json:"instances"`
+ Parameters map[string]any `json:"parameters,omitempty"`
+}
+
+type submitResponse struct {
+ Name string `json:"name"`
+}
+
+type operationVideo struct {
+ MimeType string `json:"mimeType"`
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+ Encoding string `json:"encoding"`
+}
+
+type operationResponse struct {
+ Name string `json:"name"`
+ Done bool `json:"done"`
+ Response struct {
+ Type string `json:"@type"`
+ RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
+ Videos []operationVideo `json:"videos"`
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+ Encoding string `json:"encoding"`
+ Video string `json:"video"`
+ } `json:"response"`
+ Error struct {
+ Message string `json:"message"`
+ } `json:"error"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ apiKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.ChannelBaseUrl
+ a.apiKey = info.ApiKey
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+ // Use the standard validation method for TaskSubmitReq
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ adc := &vertexcore.Credentials{}
+ if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+ return "", fmt.Errorf("failed to decode credentials: %w", err)
+ }
+ modelName := info.OriginModelName
+ if modelName == "" {
+ modelName = "veo-3.0-generate-001"
+ }
+
+ region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
+ if strings.TrimSpace(region) == "" {
+ region = "global"
+ }
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
+ adc.ProjectID,
+ modelName,
+ ), nil
+ }
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
+ region,
+ adc.ProjectID,
+ region,
+ modelName,
+ ), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ adc := &vertexcore.Credentials{}
+ if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+ return fmt.Errorf("failed to decode credentials: %w", err)
+ }
+
+ token, err := vertexcore.AcquireAccessToken(*adc, "")
+ if err != nil {
+ return fmt.Errorf("failed to acquire access token: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("x-goog-user-project", adc.ProjectID)
+ return nil
+}
+
+// BuildRequestBody converts request into Vertex specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
+ v, ok := c.Get("task_request")
+ if !ok {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(relaycommon.TaskSubmitReq)
+
+ body := requestPayload{
+ Instances: []map[string]any{{"prompt": req.Prompt}},
+ Parameters: map[string]any{},
+ }
+ if req.Metadata != nil {
+ if v, ok := req.Metadata["storageUri"]; ok {
+ body.Parameters["storageUri"] = v
+ }
+ if v, ok := req.Metadata["sampleCount"]; ok {
+ body.Parameters["sampleCount"] = v
+ }
+ }
+ if _, ok := body.Parameters["sampleCount"]; !ok {
+ body.Parameters["sampleCount"] = 1
+ }
+
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ }
+ _ = resp.Body.Close()
+
+ var s submitResponse
+ if err := json.Unmarshal(responseBody, &s); err != nil {
+ return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
+ }
+ if strings.TrimSpace(s.Name) == "" {
+ return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
+ }
+ localID := encodeLocalTaskID(s.Name)
+ c.JSON(http.StatusOK, gin.H{"task_id": localID})
+ return localID, responseBody, nil
+}
+
+func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
+func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+ upstreamName, err := decodeLocalTaskID(taskID)
+ if err != nil {
+ return nil, fmt.Errorf("decode task_id failed: %w", err)
+ }
+ region := extractRegionFromOperationName(upstreamName)
+ if region == "" {
+ region = "us-central1"
+ }
+ project := extractProjectFromOperationName(upstreamName)
+ modelName := extractModelFromOperationName(upstreamName)
+ if project == "" || modelName == "" {
+ return nil, fmt.Errorf("cannot extract project/model from operation name")
+ }
+ var url string
+ if region == "global" {
+ url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
+ } else {
+ url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
+ }
+ payload := map[string]string{"operationName": upstreamName}
+ data, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+ adc := &vertexcore.Credentials{}
+ if err := json.Unmarshal([]byte(key), adc); err != nil {
+ return nil, fmt.Errorf("failed to decode credentials: %w", err)
+ }
+ token, err := vertexcore.AcquireAccessToken(*adc, "")
+ if err != nil {
+ return nil, fmt.Errorf("failed to acquire access token: %w", err)
+ }
+ req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("x-goog-user-project", adc.ProjectID)
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+ var op operationResponse
+ if err := json.Unmarshal(respBody, &op); err != nil {
+ return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
+ }
+ ti := &relaycommon.TaskInfo{}
+ if op.Error.Message != "" {
+ ti.Status = model.TaskStatusFailure
+ ti.Reason = op.Error.Message
+ ti.Progress = "100%"
+ return ti, nil
+ }
+ if !op.Done {
+ ti.Status = model.TaskStatusInProgress
+ ti.Progress = "50%"
+ return ti, nil
+ }
+ ti.Status = model.TaskStatusSuccess
+ ti.Progress = "100%"
+ if len(op.Response.Videos) > 0 {
+ v0 := op.Response.Videos[0]
+ if v0.BytesBase64Encoded != "" {
+ mime := strings.TrimSpace(v0.MimeType)
+ if mime == "" {
+ enc := strings.TrimSpace(v0.Encoding)
+ if enc == "" {
+ enc = "mp4"
+ }
+ if strings.Contains(enc, "/") {
+ mime = enc
+ } else {
+ mime = "video/" + enc
+ }
+ }
+ ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
+ return ti, nil
+ }
+ }
+ if op.Response.BytesBase64Encoded != "" {
+ enc := strings.TrimSpace(op.Response.Encoding)
+ if enc == "" {
+ enc = "mp4"
+ }
+ mime := enc
+ if !strings.Contains(enc, "/") {
+ mime = "video/" + enc
+ }
+ ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
+ return ti, nil
+ }
+ if op.Response.Video != "" { // some variants use `video` as base64
+ enc := strings.TrimSpace(op.Response.Encoding)
+ if enc == "" {
+ enc = "mp4"
+ }
+ mime := enc
+ if !strings.Contains(enc, "/") {
+ mime = "video/" + enc
+ }
+ ti.Url = "data:" + mime + ";base64," + op.Response.Video
+ return ti, nil
+ }
+ return ti, nil
+}
+
+// ============================
+// helpers
+// ============================
+
+func encodeLocalTaskID(name string) string {
+ return base64.RawURLEncoding.EncodeToString([]byte(name))
+}
+
+func decodeLocalTaskID(local string) (string, error) {
+ b, err := base64.RawURLEncoding.DecodeString(local)
+ if err != nil {
+ return "", err
+ }
+ return string(b), nil
+}
+
+var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
+
+func extractRegionFromOperationName(name string) string {
+ m := regionRe.FindStringSubmatch(name)
+ if len(m) == 2 {
+ return m[1]
+ }
+ return ""
+}
+
+var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
+
+func extractModelFromOperationName(name string) string {
+ m := modelRe.FindStringSubmatch(name)
+ if len(m) == 2 {
+ return m[1]
+ }
+ idx := strings.Index(name, "models/")
+ if idx >= 0 {
+ s := name[idx+len("models/"):]
+ if p := strings.Index(s, "/operations/"); p > 0 {
+ return s[:p]
+ }
+ }
+ return ""
+}
+
+var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
+
+func extractProjectFromOperationName(name string) string {
+ m := projectRe.FindStringSubmatch(name)
+ if len(m) == 2 {
+ return m[1]
+ }
+ return ""
+}
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
index c82c1c0e..358aef58 100644
--- a/relay/channel/task/vidu/adaptor.go
+++ b/relay/channel/task/vidu/adaptor.go
@@ -23,16 +23,6 @@ import (
// Request / Response structures
// ============================
-type SubmitReq struct {
- Prompt string `json:"prompt"`
- Model string `json:"model,omitempty"`
- Mode string `json:"mode,omitempty"`
- Image string `json:"image,omitempty"`
- Size string `json:"size,omitempty"`
- Duration int `json:"duration,omitempty"`
- Metadata map[string]interface{} `json:"metadata,omitempty"`
-}
-
type requestPayload struct {
Model string `json:"model"`
Images []string `json:"images"`
@@ -90,23 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
- var req SubmitReq
- if err := c.ShouldBindJSON(&req); err != nil {
- return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
- }
-
- if req.Prompt == "" {
- return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
- }
-
- if req.Image != "" {
- info.Action = constant.TaskActionGenerate
- } else {
- info.Action = constant.TaskActionTextGenerate
- }
-
- c.Set("task_request", req)
- return nil
+ return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -114,7 +88,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
if !exists {
return nil, fmt.Errorf("request not found in context")
}
- req := v.(SubmitReq)
+ req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
@@ -137,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
switch info.Action {
case constant.TaskActionGenerate:
path = "/img2video"
+ case constant.TaskActionFirstTailGenerate:
+ path = "/start-end2video"
+ case constant.TaskActionReferenceGenerate:
+ path = "/reference2video"
default:
path = "/text2video"
}
@@ -211,15 +189,10 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
-func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
- var images []string
- if req.Image != "" {
- images = []string{req.Image}
- }
-
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
- Images: images,
+ Images: req.Images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 0b6b2674..c4781813 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
+ "one-api/common"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
@@ -36,6 +37,7 @@ var claudeModelMap = map[string]string{
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
"claude-opus-4-20250514": "claude-opus-4@20250514",
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -80,16 +82,91 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
}
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- adc := &Credentials{}
- if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
- return "", fmt.Errorf("failed to decode credentials file: %w", err)
- }
+func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
- a.AccountCredentials = *adc
+ if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
+ adc := &Credentials{}
+ if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
+ return "", fmt.Errorf("failed to decode credentials file: %w", err)
+ }
+ a.AccountCredentials = *adc
+
+ if a.RequestMode == RequestModeGemini {
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+ adc.ProjectID,
+ modelName,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ modelName,
+ suffix,
+ ), nil
+ }
+ } else if a.RequestMode == RequestModeClaude {
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+ adc.ProjectID,
+ modelName,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ modelName,
+ suffix,
+ ), nil
+ }
+ } else if a.RequestMode == RequestModeLlama {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
+ region,
+ adc.ProjectID,
+ region,
+ ), nil
+ }
+ } else {
+ var keyPrefix string
+ if strings.HasSuffix(suffix, "?alt=sse") {
+ keyPrefix = "&"
+ } else {
+ keyPrefix = "?"
+ }
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
+ modelName,
+ suffix,
+ keyPrefix,
+ info.ApiKey,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
+ region,
+ modelName,
+ suffix,
+ keyPrefix,
+ info.ApiKey,
+ ), nil
+ }
+ }
+ return "", errors.New("unsupported request mode")
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := ""
if a.RequestMode == RequestModeGemini {
-
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-
格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -111,24 +188,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
-
- if region == "global" {
- return fmt.Sprintf(
- "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
- adc.ProjectID,
- info.UpstreamModelName,
- suffix,
- ), nil
- } else {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- info.UpstreamModelName,
- suffix,
- ), nil
- }
+ return a.getRequestUrl(info, info.UpstreamModelName, suffix)
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
@@ -139,41 +199,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
- if region == "global" {
- return fmt.Sprintf(
- "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
- adc.ProjectID,
- model,
- suffix,
- ), nil
- } else {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- model,
- suffix,
- ), nil
- }
+ return a.getRequestUrl(info, model, suffix)
} else if a.RequestMode == RequestModeLlama {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
- region,
- adc.ProjectID,
- region,
- ), nil
+ return a.getRequestUrl(info, "", "")
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
- accessToken, err := getAccessToken(a, info)
- if err != nil {
- return err
+ if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
+ accessToken, err := getAccessToken(a, info)
+ if err != nil {
+ return err
+ }
+ req.Set("Authorization", "Bearer "+accessToken)
+ }
+ if a.AccountCredentials.ProjectID != "" {
+ req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
}
- req.Set("Authorization", "Bearer "+accessToken)
return nil
}
diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go
index 5ed87665..f0b84906 100644
--- a/relay/channel/vertex/relay-vertex.go
+++ b/relay/channel/vertex/relay-vertex.go
@@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
- return m["default"].(string)
+ if v, ok := m["default"]; ok {
+ return v.(string)
+ }
+ return "global"
}
}
return other
diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go
index 9a4650d9..f90d5454 100644
--- a/relay/channel/vertex/service_account.go
+++ b/relay/channel/vertex/service_account.go
@@ -6,14 +6,15 @@ import (
"encoding/json"
"encoding/pem"
"errors"
- "github.com/bytedance/gopkg/cache/asynccache"
- "github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
+ "github.com/bytedance/gopkg/cache/asynccache"
+ "github.com/golang-jwt/jwt"
+
"fmt"
"time"
)
@@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
return "", fmt.Errorf("failed to get access token: %v", result)
}
+
+func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
+ signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
+ if err != nil {
+ return "", fmt.Errorf("failed to create signed JWT: %w", err)
+ }
+ return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
+}
+
+func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
+ authURL := "https://www.googleapis.com/oauth2/v4/token"
+ data := url.Values{}
+ data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+ data.Set("assertion", signedJWT)
+
+ var client *http.Client
+ var err error
+ if proxy != "" {
+ client, err = service.NewProxyHttpClient(proxy)
+ if err != nil {
+ return "", fmt.Errorf("new proxy http client failed: %w", err)
+ }
+ } else {
+ client = service.GetHttpClient()
+ }
+
+ resp, err := client.PostForm(authURL, data)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ var result map[string]interface{}
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return "", err
+ }
+
+ if accessToken, ok := result["access_token"].(string); ok {
+ return accessToken, nil
+ }
+ return "", fmt.Errorf("failed to get access token: %v", result)
+}
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
index 0af019da..234ab4c9 100644
--- a/relay/channel/volcengine/adaptor.go
+++ b/relay/channel/volcengine/adaptor.go
@@ -9,6 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
+ channelconstant "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
@@ -41,6 +42,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
+ case constant.RelayModeImagesGenerations:
+ return request, nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
@@ -186,21 +189,35 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
+ // 支持自定义域名,如果未设置则使用默认域名
+ baseUrl := info.ChannelBaseUrl
+ if baseUrl == "" {
+ baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
+ }
+
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
- return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
+ return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
}
- return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
- case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
- case constant.RelayModeImagesGenerations:
- return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
- case constant.RelayModeImagesEdits:
- return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
- case constant.RelayModeRerank:
- return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
+ return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
default:
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ if strings.HasPrefix(info.UpstreamModelName, "bot") {
+ return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
+ }
+ return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
+ case constant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
+ case constant.RelayModeImagesGenerations:
+ return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
+ case constant.RelayModeImagesEdits:
+ return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
+ case constant.RelayModeRerank:
+ return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
+ default:
+ }
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go
index 30cc902e..87a12b27 100644
--- a/relay/channel/volcengine/constants.go
+++ b/relay/channel/volcengine/constants.go
@@ -8,6 +8,12 @@ var ModelList = []string{
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
+ "doubao-seedream-4-0-250828",
+ "seedream-4-0-250828",
+ "doubao-seedance-1-0-pro-250528",
+ "seedance-1-0-pro-250528",
+ "doubao-seed-1-6-thinking-250715",
+ "seed-1-6-thinking-250715",
}
var ChannelName = "volcengine"
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
index 9d5c190f..9503d5d3 100644
--- a/relay/channel/xunfei/relay-xunfei.go
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return nil, nil, err
}
- defer func() {
- conn.Close()
- }()
-
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
@@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
+ defer func() {
+ conn.Close()
+ }()
for {
_, msg, err := conn.ReadMessage()
if err != nil {
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index 59c052f6..3a739785 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
info.UpstreamModelName = request.Model
}
+ if info.ChannelSetting.SystemPrompt != "" {
+ if request.System == nil {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt)
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ if request.IsStringSystem() {
+ existing := strings.TrimSpace(request.GetStringSystem())
+ if existing == "" {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt)
+ } else {
+ request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
+ }
+ } else {
+ systemContents := request.ParseSystem()
+ newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
+ newSystem.SetText(info.ChannelSetting.SystemPrompt)
+ if len(systemContents) == 0 {
+ request.System = []dto.ClaudeMediaMessage{newSystem}
+ } else {
+ request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
+ }
+ }
+ }
+ }
+
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
@@ -86,6 +112,12 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ // remove disabled fields for Claude API
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
@@ -111,7 +143,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index da572c07..35f8ad19 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -105,7 +105,8 @@ type RelayInfo struct {
UserQuota int
RelayFormat types.RelayFormat
SendResponseCount int
- FinalPreConsumedQuota int // 最终预消耗的配额
+ FinalPreConsumedQuota int // 最终预消耗的配额
+ IsClaudeBetaQuery bool // /v1/messages?beta=true
PriceData types.PriceData
@@ -279,6 +280,9 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
+ if c.Query("beta") == "true" {
+ info.IsClaudeBetaQuery = true
+ }
return info
}
@@ -481,11 +485,20 @@ type TaskSubmitReq struct {
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
+ Images []string `json:"images,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
+func (t TaskSubmitReq) GetPrompt() string {
+ return t.Prompt
+}
+
+func (t TaskSubmitReq) HasImage() bool {
+ return len(t.Images) > 0
+}
+
type TaskInfo struct {
Code int `json:"code"`
TaskID string `json:"task_id"`
@@ -494,3 +507,43 @@ type TaskInfo struct {
Url string `json:"url,omitempty"`
Progress string `json:"progress,omitempty"`
}
+
+// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
+// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
+// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
+// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
+func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
+ var data map[string]interface{}
+ if err := common.Unmarshal(jsonData, &data); err != nil {
+ common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
+ return jsonData, nil
+ }
+
+ // 默认移除 service_tier,除非明确允许(避免额外计费风险)
+ if !channelOtherSettings.AllowServiceTier {
+ if _, exists := data["service_tier"]; exists {
+ delete(data, "service_tier")
+ }
+ }
+
+ // 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
+ if channelOtherSettings.DisableStore {
+ if _, exists := data["store"]; exists {
+ delete(data, "store")
+ }
+ }
+
+ // 默认移除 safety_identifier,除非明确允许(保护用户隐私,避免向 OpenAI 报告用户信息)
+ if !channelOtherSettings.AllowSafetyIdentifier {
+ if _, exists := data["safety_identifier"]; exists {
+ delete(data, "safety_identifier")
+ }
+ }
+
+ jsonDataAfter, err := common.Marshal(data)
+ if err != nil {
+ common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
+ return jsonData, nil
+ }
+ return jsonDataAfter, nil
+}
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index 3d5efcb6..3a721b47 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -2,12 +2,23 @@ package common
import (
"fmt"
+ "net/http"
+ "one-api/common"
"one-api/constant"
+ "one-api/dto"
"strings"
"github.com/gin-gonic/gin"
)
+type HasPrompt interface {
+ GetPrompt() string
+}
+
+type HasImage interface {
+ HasImage() bool
+}
+
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
@@ -30,3 +41,56 @@ func GetAPIVersion(c *gin.Context) string {
}
return apiVersion
}
+
+func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
+ return &dto.TaskError{
+ Code: code,
+ Message: err.Error(),
+ StatusCode: statusCode,
+ LocalError: localError,
+ Error: err,
+ }
+}
+
+func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
+ info.Action = action
+ c.Set("task_request", requestObj)
+}
+
+func validatePrompt(prompt string) *dto.TaskError {
+ if strings.TrimSpace(prompt) == "" {
+ return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
+ }
+ return nil
+}
+
+func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
+ var req TaskSubmitReq
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
+ }
+
+ if taskErr := validatePrompt(req.Prompt); taskErr != nil {
+ return taskErr
+ }
+
+ if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
+ // 兼容单图上传
+ req.Images = []string{req.Image}
+ }
+
+ if req.HasImage() {
+ action = constant.TaskActionGenerate
+ if info.ChannelType == constant.ChannelTypeVidu {
+ // vidu 增加 首尾帧生视频和参考图生视频
+ if len(req.Images) == 2 {
+ action = constant.TaskActionFirstTailGenerate
+ } else if len(req.Images) > 2 {
+ action = constant.TaskActionReferenceGenerate
+ }
+ }
+ }
+
+ storeTaskRequest(c, info, action, req)
+ return nil
+}
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go
index a3c6ace6..a3ddf6d4 100644
--- a/relay/compatible_handler.go
+++ b/relay/compatible_handler.go
@@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
- request := convertedRequest.(*dto.GeneralOpenAIRequest)
- containSystemPrompt := false
- for _, message := range request.Messages {
- if message.Role == request.GetSystemRoleName() {
- containSystemPrompt = true
- break
- }
- }
- if !containSystemPrompt {
- // 如果没有系统提示,则添加系统提示
- systemMessage := dto.Message{
- Role: request.GetSystemRoleName(),
- Content: info.ChannelSetting.SystemPrompt,
- }
- request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
- } else if info.ChannelSetting.SystemPromptOverride {
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- // 如果有系统提示,且允许覆盖,则拼接到前面
- for i, message := range request.Messages {
+ request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
+ if ok {
+ containSystemPrompt := false
+ for _, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
- if message.IsStringContent() {
- request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
- } else {
- contents := message.ParseContent()
- contents = append([]dto.MediaContent{
- {
- Type: dto.ContentTypeText,
- Text: info.ChannelSetting.SystemPrompt,
- },
- }, contents...)
- request.Messages[i].Content = contents
- }
+ containSystemPrompt = true
break
}
}
+ if !containSystemPrompt {
+ // 如果没有系统提示,则添加系统提示
+ systemMessage := dto.Message{
+ Role: request.GetSystemRoleName(),
+ Content: info.ChannelSetting.SystemPrompt,
+ }
+ request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ // 如果有系统提示,且允许覆盖,则拼接到前面
+ for i, message := range request.Messages {
+ if message.Role == request.GetSystemRoleName() {
+ if message.IsStringContent() {
+ request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
+ } else {
+ contents := message.ParseContent()
+ contents = append([]dto.MediaContent{
+ {
+ Type: dto.ContentTypeText,
+ Text: info.ChannelSetting.SystemPrompt,
+ },
+ }, contents...)
+ request.Messages[i].Content = contents
+ }
+ break
+ }
+ }
+ }
}
}
@@ -133,6 +135,12 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
}
+ // remove disabled fields for OpenAI API
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
@@ -158,7 +166,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- newApiErr := service.RelayErrorHandler(httpResp, false)
+ newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr
@@ -195,6 +203,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
imageTokens := usage.PromptTokensDetails.ImageTokens
audioTokens := usage.PromptTokensDetails.AudioTokens
completionTokens := usage.CompletionTokens
+ cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
+
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
@@ -204,6 +214,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
+ cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -211,12 +222,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dImageTokens := decimal.NewFromInt(int64(imageTokens))
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+ dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio)
dImageRatio := decimal.NewFromFloat(imageRatio)
dModelRatio := decimal.NewFromFloat(modelRatio)
dGroupRatio := decimal.NewFromFloat(groupRatio)
dModelPrice := decimal.NewFromFloat(modelPrice)
+ dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
ratio := dModelRatio.Mul(dGroupRatio)
@@ -271,6 +284,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
+ var dImageGenerationCallQuota decimal.Decimal
+ var imageGenerationCallPrice float64
+ if ctx.GetBool("image_generation_call") {
+ imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
+ dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
+ }
var quotaCalculateDecimal decimal.Decimal
@@ -284,6 +304,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
baseTokens = baseTokens.Sub(dCacheTokens)
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
+ var dCachedCreationTokensWithRatio decimal.Decimal
+ if !dCachedCreationTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dCachedCreationTokens)
+ dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
+ }
// 减去 image tokens
var imageTokensWithRatio decimal.Decimal
@@ -302,7 +327,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
}
}
- promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+ promptQuota := baseTokens.Add(cachedTokensWithRatio).
+ Add(imageTokensWithRatio).
+ Add(dCachedCreationTokensWithRatio)
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
@@ -314,22 +341,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
} else {
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
}
- var dGeminiImageOutputQuota decimal.Decimal
- var imageOutputPrice float64
- if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
- imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
- if imageOutputPrice > 0 {
- dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
- dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- }
- }
// 添加 responses tools call 调用的配额
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
// 添加 audio input 独立计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
- // 添加 Gemini image output 计费
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
+ // 添加 image generation call 计费
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
@@ -395,6 +413,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["image_ratio"] = imageRatio
other["image_output"] = imageTokens
}
+ if cachedCreationTokens != 0 {
+ other["cache_creation_tokens"] = cachedCreationTokens
+ other["cache_creation_ratio"] = cachedCreationRatio
+ }
if !dWebSearchQuota.IsZero() {
if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
@@ -424,9 +446,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
- if !dGeminiImageOutputQuota.IsZero() {
- other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
- other["image_output_price"] = imageOutputPrice
+ if !dImageGenerationCallQuota.IsZero() {
+ other["image_generation_call"] = true
+ other["image_generation_call_price"] = imageGenerationCallPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
index 26dcf971..3d8962bb 100644
--- a/relay/embedding_handler.go
+++ b/relay/embedding_handler.go
@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index 460fd2f5..1410da60 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/gemini"
@@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
adaptor.Init(info)
+ if info.ChannelSetting.SystemPrompt != "" {
+ if request.SystemInstructions == nil {
+ request.SystemInstructions = &dto.GeminiChatContent{
+ Parts: []dto.GeminiPart{
+ {Text: info.ChannelSetting.SystemPrompt},
+ },
+ }
+ } else if len(request.SystemInstructions.Parts) == 0 {
+ request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
+ } else if info.ChannelSetting.SystemPromptOverride {
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ merged := false
+ for i := range request.SystemInstructions.Parts {
+ if request.SystemInstructions.Parts[i].Text == "" {
+ continue
+ }
+ request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
+ merged = true
+ break
+ }
+ if !merged {
+ request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
+ }
+ }
+ }
+
// Clean up empty system instruction
if request.SystemInstructions != nil {
hasContent := false
@@ -152,7 +179,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
@@ -249,7 +276,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
diff --git a/relay/helper/price.go b/relay/helper/price.go
index fdc5b66d..c23c068b 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheRatio float64
var imageRatio float64
var cacheCreationRatio float64
+ var audioRatio float64
+ var audioCompletionRatio float64
if !usePrice {
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
if meta.MaxTokens != 0 {
@@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
+ audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
+ audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
@@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
+ AudioRatio: audioRatio,
+ AudioCompletionRatio: audioCompletionRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}
diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go
index 4d1c1f9b..f4a290ec 100644
--- a/relay/helper/valid_request.go
+++ b/relay/helper/valid_request.go
@@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
- request, err = GetAndValidateGeminiRequest(c)
+ if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
+ request, err = GetAndValidateGeminiEmbeddingRequest(c)
+ } else {
+ request, err = GetAndValidateGeminiRequest(c)
+ }
case types.RelayFormatClaude:
request, err = GetAndValidateClaudeRequest(c)
case types.RelayFormatOpenAIResponses:
@@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
}
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
-
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
@@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
return request, nil
}
+
+func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
+ request := &dto.GeminiEmbeddingRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ return request, nil
+}
diff --git a/relay/image_handler.go b/relay/image_handler.go
index 14a7103c..9c873d47 100644
--- a/relay/image_handler.go
+++ b/relay/image_handler.go
@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
@@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
var logContent string
if len(request.Size) > 0 {
- logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
+ logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
}
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
diff --git a/relay/mjproxy_handler.go b/relay/mjproxy_handler.go
index 7c52cb6b..ec8dfc6b 100644
--- a/relay/mjproxy_handler.go
+++ b/relay/mjproxy_handler.go
@@ -16,6 +16,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
+ "one-api/setting/system_setting"
"strconv"
"strings"
"time"
@@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
- midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
+ midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 1ee85986..406074c5 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -1,7 +1,6 @@
package relay
import (
- "github.com/gin-gonic/gin"
"one-api/constant"
"one-api/relay/channel"
"one-api/relay/channel/ali"
@@ -28,6 +27,7 @@ import (
taskjimeng "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
+ taskvertex "one-api/relay/channel/task/vertex"
taskVidu "one-api/relay/channel/task/vidu"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
@@ -37,6 +37,8 @@ import (
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
"strconv"
+ "one-api/relay/channel/submodel"
+ "github.com/gin-gonic/gin"
)
func GetAdaptor(apiType int) channel.Adaptor {
@@ -101,6 +103,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &jimeng.Adaptor{}
case constant.APITypeMoonshot:
return &moonshot.Adaptor{} // Moonshot uses Claude API
+ case constant.APITypeSubmodel:
+ return &submodel.Adaptor{}
}
return nil
}
@@ -126,6 +130,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &kling.TaskAdaptor{}
case constant.ChannelTypeJimeng:
return &taskjimeng.TaskAdaptor{}
+ case constant.ChannelTypeVertexAi:
+ return &taskvertex.TaskAdaptor{}
case constant.ChannelTypeVidu:
return &taskVidu.TaskAdaptor{}
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 0754e023..9cb8cd5c 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -15,6 +15,8 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting/ratio_setting"
+ "strconv"
+ "strings"
"github.com/gin-gonic/gin"
)
@@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
platform = GetTaskPlatform(c)
}
+ info.InitChannelMeta(c)
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
@@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
if taskErr != nil {
return taskErr
}
+ if len(respBody) == 0 {
+ respBody = []byte("{\"code\":\"success\",\"data\":null}")
+ }
c.Writer.Header().Set("Content-Type", "application/json")
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
@@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return
}
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
+ func() {
+ channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
+ if err2 != nil {
+ return
+ }
+ if channelModel.Type != constant.ChannelTypeVertexAi {
+ return
+ }
+ baseURL := constant.ChannelBaseURLs[channelModel.Type]
+ if channelModel.GetBaseURL() != "" {
+ baseURL = channelModel.GetBaseURL()
+ }
+ adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
+ if adaptor == nil {
+ return
+ }
+ resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
+ "task_id": originTask.TaskID,
+ "action": originTask.Action,
+ })
+ if err2 != nil || resp == nil {
+ return
+ }
+ defer resp.Body.Close()
+ body, err2 := io.ReadAll(resp.Body)
+ if err2 != nil {
+ return
+ }
+ ti, err2 := adaptor.ParseTaskResult(body)
+ if err2 == nil && ti != nil {
+ if ti.Status != "" {
+ originTask.Status = model.TaskStatus(ti.Status)
+ }
+ if ti.Progress != "" {
+ originTask.Progress = ti.Progress
+ }
+ if ti.Url != "" {
+ originTask.FailReason = ti.Url
+ }
+ _ = originTask.Update()
+ var raw map[string]any
+ _ = json.Unmarshal(body, &raw)
+ format := "mp4"
+ if respObj, ok := raw["response"].(map[string]any); ok {
+ if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
+ if v0, ok := vids[0].(map[string]any); ok {
+ if mt, ok := v0["mimeType"].(string); ok && mt != "" {
+ if strings.Contains(mt, "mp4") {
+ format = "mp4"
+ } else {
+ format = mt
+ }
+ }
+ }
+ }
+ }
+ status := "processing"
+ switch originTask.Status {
+ case model.TaskStatusSuccess:
+ status = "succeeded"
+ case model.TaskStatusFailure:
+ status = "failed"
+ case model.TaskStatusQueued, model.TaskStatusSubmitted:
+ status = "queued"
+ }
+ out := map[string]any{
+ "error": nil,
+ "format": format,
+ "metadata": nil,
+ "status": status,
+ "task_id": originTask.TaskID,
+ "url": originTask.FailReason,
+ }
+ respBody, _ = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: out,
+ })
+ }
+ }()
+
+ if len(respBody) == 0 {
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ }
return
}
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
index fa3c7bbb..46d2e25f 100644
--- a/relay/rerank_handler.go
+++ b/relay/rerank_handler.go
@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index f5f624c9..6958f96e 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
}
adaptor.Init(info)
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
@@ -56,6 +56,13 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+
+ // remove disabled fields for OpenAI Responses API
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
@@ -82,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
diff --git a/router/api-router.go b/router/api-router.go
index 77385738..d2961591 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -40,11 +40,17 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
+ // Universal secure verification routes
+ apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
+ apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus)
+
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
+ userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), controller.PasskeyLoginBegin)
+ userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), controller.PasskeyLoginFinish)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
@@ -59,7 +65,14 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.PUT("/self", controller.UpdateSelf)
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
+ selfRoute.GET("/passkey", controller.PasskeyStatus)
+ selfRoute.POST("/passkey/register/begin", controller.PasskeyRegisterBegin)
+ selfRoute.POST("/passkey/register/finish", controller.PasskeyRegisterFinish)
+ selfRoute.POST("/passkey/verify/begin", controller.PasskeyVerifyBegin)
+ selfRoute.POST("/passkey/verify/finish", controller.PasskeyVerifyFinish)
+ selfRoute.DELETE("/passkey", controller.PasskeyDelete)
selfRoute.GET("/aff", controller.GetAffCode)
+ selfRoute.GET("/topup/info", controller.GetTopUpInfo)
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
@@ -86,6 +99,7 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser)
+ adminRoute.DELETE("/:id/reset_passkey", controller.AdminResetPasskey)
// Admin 2FA routes
adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
@@ -114,7 +128,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/models", controller.ChannelListModels)
channelRoute.GET("/models_enabled", controller.EnabledListModels)
channelRoute.GET("/:id", controller.GetChannel)
- channelRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), controller.GetChannelKey)
+ channelRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), middleware.SecureVerificationRequired(), controller.GetChannelKey)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
diff --git a/service/cf_worker.go b/service/cf_worker.go
deleted file mode 100644
index 4a7b4376..00000000
--- a/service/cf_worker.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package service
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/setting"
- "strings"
-)
-
-// WorkerRequest Worker请求的数据结构
-type WorkerRequest struct {
- URL string `json:"url"`
- Key string `json:"key"`
- Method string `json:"method,omitempty"`
- Headers map[string]string `json:"headers,omitempty"`
- Body json.RawMessage `json:"body,omitempty"`
-}
-
-// DoWorkerRequest 通过Worker发送请求
-func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
- if !setting.EnableWorker() {
- return nil, fmt.Errorf("worker not enabled")
- }
- if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
- return nil, fmt.Errorf("only support https url")
- }
-
- workerUrl := setting.WorkerUrl
- if !strings.HasSuffix(workerUrl, "/") {
- workerUrl += "/"
- }
-
- // 序列化worker请求数据
- workerPayload, err := json.Marshal(req)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
- }
-
- return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
-}
-
-func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
- if setting.EnableWorker() {
- common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
- req := &WorkerRequest{
- URL: originUrl,
- Key: setting.WorkerValidKey,
- }
- return DoWorkerRequest(req)
- } else {
- common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
- return http.Get(originUrl)
- }
-}
diff --git a/service/download.go b/service/download.go
new file mode 100644
index 00000000..036c43af
--- /dev/null
+++ b/service/download.go
@@ -0,0 +1,69 @@
+package service
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/setting/system_setting"
+ "strings"
+)
+
+// WorkerRequest Worker请求的数据结构
+type WorkerRequest struct {
+ URL string `json:"url"`
+ Key string `json:"key"`
+ Method string `json:"method,omitempty"`
+ Headers map[string]string `json:"headers,omitempty"`
+ Body json.RawMessage `json:"body,omitempty"`
+}
+
+// DoWorkerRequest 通过Worker发送请求
+func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
+ if !system_setting.EnableWorker() {
+ return nil, fmt.Errorf("worker not enabled")
+ }
+ if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
+ return nil, fmt.Errorf("only support https url")
+ }
+
+ // SSRF防护:验证请求URL
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return nil, fmt.Errorf("request reject: %v", err)
+ }
+
+ workerUrl := system_setting.WorkerUrl
+ if !strings.HasSuffix(workerUrl, "/") {
+ workerUrl += "/"
+ }
+
+ // 序列化worker请求数据
+ workerPayload, err := json.Marshal(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
+ }
+
+ return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+}
+
+func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
+ if system_setting.EnableWorker() {
+ common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
+ req := &WorkerRequest{
+ URL: originUrl,
+ Key: system_setting.WorkerValidKey,
+ }
+ return DoWorkerRequest(req)
+ } else {
+ // SSRF防护:验证请求URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return nil, fmt.Errorf("request reject: %v", err)
+ }
+
+ common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
+ return http.Get(originUrl)
+ }
+}
diff --git a/service/epay.go b/service/epay.go
index a8259d21..48b84dd5 100644
--- a/service/epay.go
+++ b/service/epay.go
@@ -1,12 +1,13 @@
package service
import (
- "one-api/setting"
+ "one-api/setting/operation_setting"
+ "one-api/setting/system_setting"
)
func GetCallbackAddress() string {
- if setting.CustomCallbackAddress == "" {
- return setting.ServerAddress
+ if operation_setting.CustomCallbackAddress == "" {
+ return system_setting.ServerAddress
}
- return setting.CustomCallbackAddress
+ return operation_setting.CustomCallbackAddress
}
diff --git a/service/error.go b/service/error.go
index ef5cbbde..5c3bddd6 100644
--- a/service/error.go
+++ b/service/error.go
@@ -1,12 +1,14 @@
package service
import (
+ "context"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/types"
"strconv"
"strings"
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr
}
-func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
+func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
responseBody, err := io.ReadAll(resp.Body)
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else {
if common.DebugEnabled {
- println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+ logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
}
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
}
diff --git a/service/http_client.go b/service/http_client.go
index b191ddd7..c1d6880c 100644
--- a/service/http_client.go
+++ b/service/http_client.go
@@ -7,12 +7,17 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "sync"
"time"
"golang.org/x/net/proxy"
)
-var httpClient *http.Client
+var (
+ httpClient *http.Client
+ proxyClientLock sync.Mutex
+ proxyClients = make(map[string]*http.Client)
+)
func InitHttpClient() {
if common.RelayTimeout == 0 {
@@ -28,12 +33,31 @@ func GetHttpClient() *http.Client {
return httpClient
}
+// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
+func ResetProxyClientCache() {
+ proxyClientLock.Lock()
+ defer proxyClientLock.Unlock()
+ for _, client := range proxyClients {
+ if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
+ transport.CloseIdleConnections()
+ }
+ }
+ proxyClients = make(map[string]*http.Client)
+}
+
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
return http.DefaultClient, nil
}
+ proxyClientLock.Lock()
+ if client, ok := proxyClients[proxyURL]; ok {
+ proxyClientLock.Unlock()
+ return client, nil
+ }
+ proxyClientLock.Unlock()
+
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, err
@@ -41,11 +65,16 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
switch parsedURL.Scheme {
case "http", "https":
- return &http.Client{
+ client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(parsedURL),
},
- }, nil
+ }
+ client.Timeout = time.Duration(common.RelayTimeout) * time.Second
+ proxyClientLock.Lock()
+ proxyClients[proxyURL] = client
+ proxyClientLock.Unlock()
+ return client, nil
case "socks5", "socks5h":
// 获取认证信息
@@ -67,15 +96,20 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
return nil, err
}
- return &http.Client{
+ client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
},
- }, nil
+ }
+ client.Timeout = time.Duration(common.RelayTimeout) * time.Second
+ proxyClientLock.Lock()
+ proxyClients[proxyURL] = client
+ proxyClientLock.Unlock()
+ return client, nil
default:
- return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
+ return nil, fmt.Errorf("unsupported proxy scheme: %s, must be http, https, socks5 or socks5h", parsedURL.Scheme)
}
}
diff --git a/service/passkey/service.go b/service/passkey/service.go
new file mode 100644
index 00000000..dc8da0cc
--- /dev/null
+++ b/service/passkey/service.go
@@ -0,0 +1,177 @@
+package passkey
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "one-api/common"
+ "one-api/setting/system_setting"
+
+ "github.com/go-webauthn/webauthn/protocol"
+ webauthn "github.com/go-webauthn/webauthn/webauthn"
+)
+
+const (
+ RegistrationSessionKey = "passkey_registration_session"
+ LoginSessionKey = "passkey_login_session"
+ VerifySessionKey = "passkey_verify_session"
+)
+
+// BuildWebAuthn constructs a WebAuthn instance using the current passkey settings and request context.
+func BuildWebAuthn(r *http.Request) (*webauthn.WebAuthn, error) {
+ settings := system_setting.GetPasskeySettings()
+ if settings == nil {
+ return nil, errors.New("未找到 Passkey 设置")
+ }
+
+ displayName := strings.TrimSpace(settings.RPDisplayName)
+ if displayName == "" {
+ displayName = common.SystemName
+ }
+
+ origins, err := resolveOrigins(r, settings)
+ if err != nil {
+ return nil, err
+ }
+
+ rpID, err := resolveRPID(r, settings, origins)
+ if err != nil {
+ return nil, err
+ }
+
+ selection := protocol.AuthenticatorSelection{
+ ResidentKey: protocol.ResidentKeyRequirementRequired,
+ RequireResidentKey: protocol.ResidentKeyRequired(),
+ UserVerification: protocol.UserVerificationRequirement(settings.UserVerification),
+ }
+ if selection.UserVerification == "" {
+ selection.UserVerification = protocol.VerificationPreferred
+ }
+ if attachment := strings.TrimSpace(settings.AttachmentPreference); attachment != "" {
+ selection.AuthenticatorAttachment = protocol.AuthenticatorAttachment(attachment)
+ }
+
+ config := &webauthn.Config{
+ RPID: rpID,
+ RPDisplayName: displayName,
+ RPOrigins: origins,
+ AuthenticatorSelection: selection,
+ Debug: common.DebugEnabled,
+ Timeouts: webauthn.TimeoutsConfig{
+ Login: webauthn.TimeoutConfig{
+ Enforce: true,
+ Timeout: 2 * time.Minute,
+ TimeoutUVD: 2 * time.Minute,
+ },
+ Registration: webauthn.TimeoutConfig{
+ Enforce: true,
+ Timeout: 2 * time.Minute,
+ TimeoutUVD: 2 * time.Minute,
+ },
+ },
+ }
+
+ return webauthn.New(config)
+}
+
+func resolveOrigins(r *http.Request, settings *system_setting.PasskeySettings) ([]string, error) {
+ originsStr := strings.TrimSpace(settings.Origins)
+ if originsStr != "" {
+ originList := strings.Split(originsStr, ",")
+ origins := make([]string, 0, len(originList))
+ for _, origin := range originList {
+ trimmed := strings.TrimSpace(origin)
+ if trimmed == "" {
+ continue
+ }
+ if !settings.AllowInsecureOrigin && strings.HasPrefix(strings.ToLower(trimmed), "http://") {
+ return nil, fmt.Errorf("Passkey 不允许使用不安全的 Origin: %s", trimmed)
+ }
+ origins = append(origins, trimmed)
+ }
+ if len(origins) == 0 {
+ // 如果配置了Origins但过滤后为空,使用自动推导
+ goto autoDetect
+ }
+ return origins, nil
+ }
+
+autoDetect:
+ scheme := detectScheme(r)
+ if scheme == "http" && !settings.AllowInsecureOrigin && r.Host != "localhost" && r.Host != "127.0.0.1" && !strings.HasPrefix(r.Host, "127.0.0.1:") && !strings.HasPrefix(r.Host, "localhost:") {
+ return nil, fmt.Errorf("Passkey 仅支持 HTTPS,当前访问: %s://%s,请在 Passkey 设置中允许不安全 Origin 或配置 HTTPS", scheme, r.Host)
+ }
+ // 优先使用请求的完整Host(包含端口)
+ host := r.Host
+
+ // 如果无法从请求获取Host,尝试从ServerAddress获取
+ if host == "" && system_setting.ServerAddress != "" {
+ if parsed, err := url.Parse(system_setting.ServerAddress); err == nil && parsed.Host != "" {
+ host = parsed.Host
+ if scheme == "" && parsed.Scheme != "" {
+ scheme = parsed.Scheme
+ }
+ }
+ }
+ if host == "" {
+ return nil, fmt.Errorf("无法确定 Passkey 的 Origin,请在系统设置或 Passkey 设置中指定。当前 Host: '%s', ServerAddress: '%s'", r.Host, system_setting.ServerAddress)
+ }
+ if scheme == "" {
+ scheme = "https"
+ }
+ origin := fmt.Sprintf("%s://%s", scheme, host)
+ return []string{origin}, nil
+}
+
+func resolveRPID(r *http.Request, settings *system_setting.PasskeySettings, origins []string) (string, error) {
+ rpID := strings.TrimSpace(settings.RPID)
+ if rpID != "" {
+ return hostWithoutPort(rpID), nil
+ }
+ if len(origins) == 0 {
+ return "", errors.New("Passkey 未配置 Origin,无法推导 RPID")
+ }
+ parsed, err := url.Parse(origins[0])
+ if err != nil {
+ return "", fmt.Errorf("无法解析 Passkey Origin: %w", err)
+ }
+ return hostWithoutPort(parsed.Host), nil
+}
+
+func hostWithoutPort(host string) string {
+ host = strings.TrimSpace(host)
+ if host == "" {
+ return ""
+ }
+ if strings.Contains(host, ":") {
+ if host, _, err := net.SplitHostPort(host); err == nil {
+ return host
+ }
+ }
+ return host
+}
+
+func detectScheme(r *http.Request) string {
+ if r == nil {
+ return ""
+ }
+ if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
+ parts := strings.Split(proto, ",")
+ return strings.ToLower(strings.TrimSpace(parts[0]))
+ }
+ if r.TLS != nil {
+ return "https"
+ }
+ if r.URL != nil && r.URL.Scheme != "" {
+ return strings.ToLower(r.URL.Scheme)
+ }
+ if r.Header.Get("X-Forwarded-Protocol") != "" {
+ return strings.ToLower(strings.TrimSpace(r.Header.Get("X-Forwarded-Protocol")))
+ }
+ return "http"
+}
diff --git a/service/passkey/session.go b/service/passkey/session.go
new file mode 100644
index 00000000..15e61932
--- /dev/null
+++ b/service/passkey/session.go
@@ -0,0 +1,50 @@
+package passkey
+
+import (
+ "encoding/json"
+ "errors"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ webauthn "github.com/go-webauthn/webauthn/webauthn"
+)
+
+var errSessionNotFound = errors.New("Passkey 会话不存在或已过期")
+
+func SaveSessionData(c *gin.Context, key string, data *webauthn.SessionData) error {
+ session := sessions.Default(c)
+ if data == nil {
+ session.Delete(key)
+ return session.Save()
+ }
+ payload, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ session.Set(key, string(payload))
+ return session.Save()
+}
+
+func PopSessionData(c *gin.Context, key string) (*webauthn.SessionData, error) {
+ session := sessions.Default(c)
+ raw := session.Get(key)
+ if raw == nil {
+ return nil, errSessionNotFound
+ }
+ session.Delete(key)
+ _ = session.Save()
+ var data webauthn.SessionData
+ switch value := raw.(type) {
+ case string:
+ if err := json.Unmarshal([]byte(value), &data); err != nil {
+ return nil, err
+ }
+ case []byte:
+ if err := json.Unmarshal(value, &data); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, errors.New("Passkey 会话格式无效")
+ }
+ return &data, nil
+}
diff --git a/service/passkey/user.go b/service/passkey/user.go
new file mode 100644
index 00000000..8b8c559f
--- /dev/null
+++ b/service/passkey/user.go
@@ -0,0 +1,71 @@
+package passkey
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "one-api/model"
+
+ webauthn "github.com/go-webauthn/webauthn/webauthn"
+)
+
+type WebAuthnUser struct {
+ user *model.User
+ credential *model.PasskeyCredential
+}
+
+func NewWebAuthnUser(user *model.User, credential *model.PasskeyCredential) *WebAuthnUser {
+ return &WebAuthnUser{user: user, credential: credential}
+}
+
+func (u *WebAuthnUser) WebAuthnID() []byte {
+ if u == nil || u.user == nil {
+ return nil
+ }
+ return []byte(strconv.Itoa(u.user.Id))
+}
+
+func (u *WebAuthnUser) WebAuthnName() string {
+ if u == nil || u.user == nil {
+ return ""
+ }
+ name := strings.TrimSpace(u.user.Username)
+ if name == "" {
+ return fmt.Sprintf("user-%d", u.user.Id)
+ }
+ return name
+}
+
+func (u *WebAuthnUser) WebAuthnDisplayName() string {
+ if u == nil || u.user == nil {
+ return ""
+ }
+ display := strings.TrimSpace(u.user.DisplayName)
+ if display != "" {
+ return display
+ }
+ return u.WebAuthnName()
+}
+
+func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
+ if u == nil || u.credential == nil {
+ return nil
+ }
+ cred := u.credential.ToWebAuthnCredential()
+ return []webauthn.Credential{cred}
+}
+
+func (u *WebAuthnUser) ModelUser() *model.User {
+ if u == nil {
+ return nil
+ }
+ return u.user
+}
+
+func (u *WebAuthnUser) PasskeyCredential() *model.PasskeyCredential {
+ if u == nil {
+ return nil
+ }
+ return u.credential
+}
diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go
index 86b04e52..0cf53513 100644
--- a/service/pre_consume_quota.go
+++ b/service/pre_consume_quota.go
@@ -13,13 +13,13 @@ import (
"github.com/gin-gonic/gin"
)
-func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
- if preConsumedQuota != 0 {
- logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+ if relayInfo.FinalPreConsumedQuota != 0 {
+ logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
gopool.Go(func() {
relayInfoCopy := *relayInfo
- err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+ err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error())
}
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
// PreConsumeQuota checks if the user has enough quota to pre-consume.
// It returns the pre-consumed quota if successful, or an error if not.
-func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
- return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+ return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota <= 0 {
- return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
if userQuota-preConsumedQuota < 0 {
- return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
trustQuota := common.GetTrustQuota()
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 {
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
- return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
- return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+ return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
}
relayInfo.FinalPreConsumedQuota = preConsumedQuota
- return preConsumedQuota, nil
+ return nil
}
diff --git a/service/quota.go b/service/quota.go
index e078a1ad..43c4024a 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -11,8 +11,8 @@ import (
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
- "one-api/setting"
"one-api/setting/ratio_setting"
+ "one-api/setting/system_setting"
"one-api/types"
"strings"
"time"
@@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
}
if quotaTooLow {
prompt := "您的额度即将用尽"
- topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
+ topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
// 根据通知方式生成不同的内容格式
var content string
@@ -549,8 +549,11 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
// Bark推送使用简短文本,不支持HTML
content = "{{value}},剩余额度:{{value}},请及时充值"
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)}
+ } else if notifyType == dto.NotifyTypeGotify {
+ content = "{{value}},当前剩余额度为 {{value}},请及时充值。"
+ values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)}
} else {
- // 默认内容格式,适用于Email和Webhook
+ // 默认内容格式,适用于Email和Webhook(支持HTML)
content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}
}
diff --git a/service/token_counter.go b/service/token_counter.go
index da56523f..be5c2e80 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -336,7 +336,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
for i, file := range meta.Files {
switch file.FileType {
case types.FileTypeImage:
- if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") {
+ if info.RelayFormat == types.RelayFormatGemini {
tkm += 256
} else {
token, err := getImageToken(file, model, info.IsStream)
diff --git a/service/user_notify.go b/service/user_notify.go
index c4a3ea91..0f92e7d7 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -1,13 +1,15 @@
package service
import (
+ "bytes"
+ "encoding/json"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
- "one-api/setting"
+ "one-api/setting/system_setting"
"strings"
)
@@ -37,13 +39,16 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
switch notifyType {
case dto.NotifyTypeEmail:
- // check setting email
- userEmail = userSetting.NotificationEmail
- if userEmail == "" {
+ // 优先使用设置中的通知邮箱,如果为空则使用用户的默认邮箱
+ emailToUse := userSetting.NotificationEmail
+ if emailToUse == "" {
+ emailToUse = userEmail
+ }
+ if emailToUse == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
- return sendEmailNotify(userEmail, data)
+ return sendEmailNotify(emailToUse, data)
case dto.NotifyTypeWebhook:
webhookURLStr := userSetting.WebhookUrl
if webhookURLStr == "" {
@@ -61,6 +66,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
return nil
}
return sendBarkNotify(barkURL, data)
+ case dto.NotifyTypeGotify:
+ gotifyUrl := userSetting.GotifyUrl
+ gotifyToken := userSetting.GotifyToken
+ if gotifyUrl == "" || gotifyToken == "" {
+ common.SysLog(fmt.Sprintf("user %d has no gotify url or token, skip sending gotify", userId))
+ return nil
+ }
+ return sendGotifyNotify(gotifyUrl, gotifyToken, userSetting.GotifyPriority, data)
}
return nil
}
@@ -91,11 +104,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
var resp *http.Response
var err error
- if setting.EnableWorker() {
+ if system_setting.EnableWorker() {
// 使用worker发送请求
workerReq := &WorkerRequest{
URL: finalURL,
- Key: setting.WorkerValidKey,
+ Key: system_setting.WorkerValidKey,
Method: http.MethodGet,
Headers: map[string]string{
"User-Agent": "OneAPI-Bark-Notify/1.0",
@@ -113,6 +126,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
}
} else {
+ // SSRF防护:验证Bark URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return fmt.Errorf("request reject: %v", err)
+ }
+
// 直接发送请求
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
if err != nil {
@@ -138,3 +157,98 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
return nil
}
+
+func sendGotifyNotify(gotifyUrl string, gotifyToken string, priority int, data dto.Notify) error {
+ // 处理占位符
+ content := data.Content
+ for _, value := range data.Values {
+ content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
+ }
+
+ // 构建完整的 Gotify API URL
+ // 确保 URL 以 /message 结尾
+ finalURL := strings.TrimSuffix(gotifyUrl, "/") + "/message?token=" + url.QueryEscape(gotifyToken)
+
+ // Gotify优先级范围0-10,如果超出范围则使用默认值5
+ if priority < 0 || priority > 10 {
+ priority = 5
+ }
+
+ // 构建 JSON payload
+ type GotifyMessage struct {
+ Title string `json:"title"`
+ Message string `json:"message"`
+ Priority int `json:"priority"`
+ }
+
+ payload := GotifyMessage{
+ Title: data.Title,
+ Message: content,
+ Priority: priority,
+ }
+
+ // 序列化为 JSON
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal gotify payload: %v", err)
+ }
+
+ var req *http.Request
+ var resp *http.Response
+
+ if system_setting.EnableWorker() {
+ // 使用worker发送请求
+ workerReq := &WorkerRequest{
+ URL: finalURL,
+ Key: system_setting.WorkerValidKey,
+ Method: http.MethodPost,
+ Headers: map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "User-Agent": "OneAPI-Gotify-Notify/1.0",
+ },
+ Body: payloadBytes,
+ }
+
+ resp, err = DoWorkerRequest(workerReq)
+ if err != nil {
+ return fmt.Errorf("failed to send gotify request through worker: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("gotify request failed with status code: %d", resp.StatusCode)
+ }
+ } else {
+ // SSRF防护:验证Gotify URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return fmt.Errorf("request reject: %v", err)
+ }
+
+ // 直接发送请求
+ req, err = http.NewRequest(http.MethodPost, finalURL, bytes.NewBuffer(payloadBytes))
+ if err != nil {
+ return fmt.Errorf("failed to create gotify request: %v", err)
+ }
+
+ // 设置请求头
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ req.Header.Set("User-Agent", "NewAPI-Gotify-Notify/1.0")
+
+ // 发送请求
+ client := GetHttpClient()
+ resp, err = client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send gotify request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("gotify request failed with status code: %d", resp.StatusCode)
+ }
+ }
+
+ return nil
+}
diff --git a/service/webhook.go b/service/webhook.go
index 8faccda3..c678b863 100644
--- a/service/webhook.go
+++ b/service/webhook.go
@@ -8,8 +8,9 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "one-api/common"
"one-api/dto"
- "one-api/setting"
+ "one-api/setting/system_setting"
"time"
)
@@ -56,11 +57,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
var req *http.Request
var resp *http.Response
- if setting.EnableWorker() {
+ if system_setting.EnableWorker() {
// 构建worker请求数据
workerReq := &WorkerRequest{
URL: webhookURL,
- Key: setting.WorkerValidKey,
+ Key: system_setting.WorkerValidKey,
Method: http.MethodPost,
Headers: map[string]string{
"Content-Type": "application/json",
@@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
+ // SSRF防护:验证Webhook URL(非Worker模式)
+ fetchSetting := system_setting.GetFetchSetting()
+ if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
+ return fmt.Errorf("request reject: %v", err)
+ }
+
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)
diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go
index 5412155f..f132fec8 100644
--- a/setting/model_setting/gemini.go
+++ b/setting/model_setting/gemini.go
@@ -26,7 +26,6 @@ var defaultGeminiSettings = GeminiSettings{
SupportedImagineModels: []string{
"gemini-2.0-flash-exp-image-generation",
"gemini-2.0-flash-exp",
- "gemini-2.5-flash-image-preview",
},
ThinkingAdapterEnabled: false,
ThinkingAdapterBudgetTokensPercentage: 0.6,
diff --git a/setting/operation_setting/payment_setting.go b/setting/operation_setting/payment_setting.go
new file mode 100644
index 00000000..c8df039c
--- /dev/null
+++ b/setting/operation_setting/payment_setting.go
@@ -0,0 +1,23 @@
+package operation_setting
+
+import "one-api/setting/config"
+
+type PaymentSetting struct {
+ AmountOptions []int `json:"amount_options"`
+ AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
+}
+
+// 默认配置
+var paymentSetting = PaymentSetting{
+ AmountOptions: []int{10, 20, 50, 100, 200, 500},
+ AmountDiscount: map[int]float64{},
+}
+
+func init() {
+ // 注册到全局配置管理器
+ config.GlobalConfig.Register("payment_setting", &paymentSetting)
+}
+
+func GetPaymentSetting() *PaymentSetting {
+ return &paymentSetting
+}
diff --git a/setting/payment.go b/setting/operation_setting/payment_setting_old.go
similarity index 57%
rename from setting/payment.go
rename to setting/operation_setting/payment_setting_old.go
index 7fc5ad3f..a6313179 100644
--- a/setting/payment.go
+++ b/setting/operation_setting/payment_setting_old.go
@@ -1,6 +1,13 @@
-package setting
+/**
+此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加
+This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
+*/
-import "encoding/json"
+package operation_setting
+
+import (
+ "one-api/common"
+)
var PayAddress = ""
var CustomCallbackAddress = ""
@@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
"color": "rgba(var(--semi-green-5), 1)",
"type": "wxpay",
},
+ {
+ "name": "自定义1",
+ "color": "black",
+ "type": "custom1",
+ "min_topup": "50",
+ },
}
func UpdatePayMethodsByJsonString(jsonString string) error {
PayMethods = make([]map[string]string, 0)
- return json.Unmarshal([]byte(jsonString), &PayMethods)
+ return common.Unmarshal([]byte(jsonString), &PayMethods)
}
func PayMethods2JsonString() string {
- jsonBytes, err := json.Marshal(PayMethods)
+ jsonBytes, err := common.Marshal(PayMethods)
if err != nil {
return "[]"
}
diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go
index b87265ee..adb76bfc 100644
--- a/setting/operation_setting/tools.go
+++ b/setting/operation_setting/tools.go
@@ -10,6 +10,18 @@ const (
FileSearchPrice = 2.5
)
+const (
+ GPTImage1Low1024x1024 = 0.011
+ GPTImage1Low1024x1536 = 0.016
+ GPTImage1Low1536x1024 = 0.016
+ GPTImage1Medium1024x1024 = 0.042
+ GPTImage1Medium1024x1536 = 0.063
+ GPTImage1Medium1536x1024 = 0.063
+ GPTImage1High1024x1024 = 0.167
+ GPTImage1High1024x1536 = 0.25
+ GPTImage1High1536x1024 = 0.25
+)
+
const (
// Gemini Audio Input Price
Gemini25FlashPreviewInputAudioPrice = 1.00
@@ -17,6 +29,7 @@ const (
Gemini25FlashLitePreviewInputAudioPrice = 0.50
Gemini25FlashNativeAudioInputAudioPrice = 3.00
Gemini20FlashInputAudioPrice = 0.70
+ GeminiRoboticsER15InputAudioPrice = 1.00
)
const (
@@ -24,10 +37,6 @@ const (
ClaudeWebSearchPrice = 10.00
)
-const (
- Gemini25FlashImagePreviewImageOutputPrice = 30.00
-)
-
func GetClaudeWebSearchPricePerThousand() float64 {
return ClaudeWebSearchPrice
}
@@ -66,13 +75,36 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
return Gemini25FlashProductionInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.0-flash") {
return Gemini20FlashInputAudioPrice
+ } else if strings.HasPrefix(modelName, "gemini-robotics-er-1.5") {
+ return GeminiRoboticsER15InputAudioPrice
}
return 0
}
-func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 {
- if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
- return Gemini25FlashImagePreviewImageOutputPrice
+func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
+ prices := map[string]map[string]float64{
+ "low": {
+ "1024x1024": GPTImage1Low1024x1024,
+ "1024x1536": GPTImage1Low1024x1536,
+ "1536x1024": GPTImage1Low1536x1024,
+ },
+ "medium": {
+ "1024x1024": GPTImage1Medium1024x1024,
+ "1024x1536": GPTImage1Medium1024x1536,
+ "1536x1024": GPTImage1Medium1536x1024,
+ },
+ "high": {
+ "1024x1024": GPTImage1High1024x1024,
+ "1024x1536": GPTImage1High1024x1536,
+ "1536x1024": GPTImage1High1536x1024,
+ },
}
- return 0
+
+ if qualityMap, exists := prices[quality]; exists {
+ if price, exists := qualityMap[size]; exists {
+ return price
+ }
+ }
+
+ return GPTImage1High1024x1024
}
diff --git a/setting/payment_stripe.go b/setting/payment_stripe.go
index 80d877df..d97120c8 100644
--- a/setting/payment_stripe.go
+++ b/setting/payment_stripe.go
@@ -5,3 +5,4 @@ var StripeWebhookSecret = ""
var StripePriceId = ""
var StripeUnitPrice = 8.0
var StripeMinTopUp = 1
+var StripePromotionCodesEnabled = false
diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go
index 5993cdee..8e4b227a 100644
--- a/setting/ratio_setting/cache_ratio.go
+++ b/setting/ratio_setting/cache_ratio.go
@@ -52,6 +52,8 @@ var defaultCacheRatio = map[string]float64{
"claude-opus-4-20250514-thinking": 0.1,
"claude-opus-4-1-20250805": 0.1,
"claude-opus-4-1-20250805-thinking": 0.1,
+ "claude-sonnet-4-5-20250929": 0.1,
+ "claude-sonnet-4-5-20250929-thinking": 0.1,
}
var defaultCreateCacheRatio = map[string]float64{
@@ -69,6 +71,8 @@ var defaultCreateCacheRatio = map[string]float64{
"claude-opus-4-20250514-thinking": 1.25,
"claude-opus-4-1-20250805": 1.25,
"claude-opus-4-1-20250805-thinking": 1.25,
+ "claude-sonnet-4-5-20250929": 1.25,
+ "claude-sonnet-4-5-20250929-thinking": 1.25,
}
//var defaultCreateCacheRatio = map[string]float64{}
diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go
index 1a1b0afa..5e55576f 100644
--- a/setting/ratio_setting/model_ratio.go
+++ b/setting/ratio_setting/model_ratio.go
@@ -141,6 +141,7 @@ var defaultModelRatio = map[string]float64{
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-sonnet-4-20250514": 1.5,
+ "claude-sonnet-4-5-20250929": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-opus-4-20250514": 7.5,
"claude-opus-4-1-20250805": 7.5,
@@ -178,7 +179,8 @@ var defaultModelRatio = map[string]float64{
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
"gemini-2.5-flash-lite-preview-06-17": 0.05,
"gemini-2.5-flash": 0.15,
- "gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens
+ "gemini-robotics-er-1.5-preview": 0.15,
+ "gemini-embedding-001": 0.075,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
@@ -251,6 +253,17 @@ var defaultModelRatio = map[string]float64{
"grok-vision-beta": 2.5,
"grok-3-fast-beta": 2.5,
"grok-3-mini-fast-beta": 0.3,
+ // submodel
+ "NousResearch/Hermes-4-405B-FP8": 0.8,
+ "Qwen/Qwen3-235B-A22B-Thinking-2507": 0.6,
+ "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8": 0.8,
+ "Qwen/Qwen3-235B-A22B-Instruct-2507": 0.3,
+ "zai-org/GLM-4.5-FP8": 0.8,
+ "openai/gpt-oss-120b": 0.5,
+ "deepseek-ai/DeepSeek-R1-0528": 0.8,
+ "deepseek-ai/DeepSeek-R1": 0.8,
+ "deepseek-ai/DeepSeek-V3-0324": 0.8,
+ "deepseek-ai/DeepSeek-V3.1": 0.8,
}
var defaultModelPrice = map[string]float64{
@@ -279,6 +292,18 @@ var defaultModelPrice = map[string]float64{
"mj_upload": 0.05,
}
+var defaultAudioRatio = map[string]float64{
+ "gpt-4o-audio-preview": 16,
+ "gpt-4o-mini-audio-preview": 66.67,
+ "gpt-4o-realtime-preview": 8,
+ "gpt-4o-mini-realtime-preview": 16.67,
+}
+
+var defaultAudioCompletionRatio = map[string]float64{
+ "gpt-4o-realtime": 2,
+ "gpt-4o-mini-realtime": 2,
+}
+
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
@@ -294,11 +319,10 @@ var (
)
var defaultCompletionRatio = map[string]float64{
- "gpt-4-gizmo-*": 2,
- "gpt-4o-gizmo-*": 3,
- "gpt-4-all": 2,
- "gpt-image-1": 8,
- "gemini-2.5-flash-image-preview": 8.3333333333,
+ "gpt-4-gizmo-*": 2,
+ "gpt-4o-gizmo-*": 3,
+ "gpt-4-all": 2,
+ "gpt-image-1": 8,
}
// InitRatioSettings initializes all model related settings maps
@@ -328,6 +352,15 @@ func InitRatioSettings() {
imageRatioMap = defaultImageRatio
imageRatioMapMutex.Unlock()
+ // initialize audioRatioMap
+ audioRatioMapMutex.Lock()
+ audioRatioMap = defaultAudioRatio
+ audioRatioMapMutex.Unlock()
+
+ // initialize audioCompletionRatioMap
+ audioCompletionRatioMapMutex.Lock()
+ audioCompletionRatioMap = defaultAudioCompletionRatio
+ audioCompletionRatioMapMutex.Unlock()
}
func GetModelPriceMap() map[string]float64 {
@@ -419,6 +452,18 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
+func GetDefaultImageRatioMap() map[string]float64 {
+ return defaultImageRatio
+}
+
+func GetDefaultAudioRatioMap() map[string]float64 {
+ return defaultAudioRatio
+}
+
+func GetDefaultAudioCompletionRatioMap() map[string]float64 {
+ return defaultAudioCompletionRatio
+}
+
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
@@ -469,7 +514,6 @@ func GetCompletionRatio(name string) float64 {
}
func getHardcodedCompletionModelRatio(name string) (float64, bool) {
- lowercaseName := strings.ToLower(name)
isReservedModel := strings.HasSuffix(name, "-all") || strings.HasSuffix(name, "-gizmo-*")
if isReservedModel {
@@ -544,6 +588,8 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
return 4, false
}
return 2.5 / 0.3, false
+ } else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
+ return 2.5 / 0.3, false
}
return 4, false
}
@@ -562,9 +608,6 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
}
// hint 只给官方上4倍率,由于开源模型供应商自行定价,不对其进行补全倍率进行强制对齐
- if lowercaseName == "deepseek-chat" || lowercaseName == "deepseek-reasoner" {
- return 4, true
- }
if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2, true
} else if strings.HasPrefix(name, "ERNIE-Lite-") {
@@ -586,32 +629,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
func GetAudioRatio(name string) float64 {
- if strings.Contains(name, "-realtime") {
- if strings.HasSuffix(name, "gpt-4o-realtime-preview") {
- return 8
- } else if strings.Contains(name, "gpt-4o-mini-realtime-preview") {
- return 10 / 0.6
- } else {
- return 20
- }
- }
- if strings.Contains(name, "-audio") {
- if strings.HasPrefix(name, "gpt-4o-audio-preview") {
- return 40 / 2.5
- } else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") {
- return 10 / 0.15
- } else {
- return 40
- }
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ name = FormatMatchingModelName(name)
+ if ratio, ok := audioRatioMap[name]; ok {
+ return ratio
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
- if strings.HasPrefix(name, "gpt-4o-realtime") {
- return 2
- } else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
- return 2
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ name = FormatMatchingModelName(name)
+ if ratio, ok := audioCompletionRatioMap[name]; ok {
+
+ return ratio
}
return 2
}
@@ -632,6 +665,14 @@ var defaultImageRatio = map[string]float64{
}
var imageRatioMap map[string]float64
var imageRatioMapMutex sync.RWMutex
+var (
+ audioRatioMap map[string]float64 = nil
+ audioRatioMapMutex = sync.RWMutex{}
+)
+var (
+ audioCompletionRatioMap map[string]float64 = nil
+ audioCompletionRatioMapMutex = sync.RWMutex{}
+)
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
@@ -660,6 +701,71 @@ func GetImageRatio(name string) (float64, bool) {
return ratio, true
}
+func AudioRatio2JSONString() string {
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ jsonBytes, err := common.Marshal(audioRatioMap)
+ if err != nil {
+ common.SysError("error marshalling audio ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateAudioRatioByJSONString(jsonStr string) error {
+
+ tmp := make(map[string]float64)
+ if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
+ return err
+ }
+ audioRatioMapMutex.Lock()
+ audioRatioMap = tmp
+ audioRatioMapMutex.Unlock()
+ InvalidateExposedDataCache()
+ return nil
+}
+
+func GetAudioRatioCopy() map[string]float64 {
+ audioRatioMapMutex.RLock()
+ defer audioRatioMapMutex.RUnlock()
+ copyMap := make(map[string]float64, len(audioRatioMap))
+ for k, v := range audioRatioMap {
+ copyMap[k] = v
+ }
+ return copyMap
+}
+
+func AudioCompletionRatio2JSONString() string {
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ jsonBytes, err := common.Marshal(audioCompletionRatioMap)
+ if err != nil {
+ common.SysError("error marshalling audio completion ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
+ tmp := make(map[string]float64)
+ if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
+ return err
+ }
+ audioCompletionRatioMapMutex.Lock()
+ audioCompletionRatioMap = tmp
+ audioCompletionRatioMapMutex.Unlock()
+ InvalidateExposedDataCache()
+ return nil
+}
+
+func GetAudioCompletionRatioCopy() map[string]float64 {
+ audioCompletionRatioMapMutex.RLock()
+ defer audioCompletionRatioMapMutex.RUnlock()
+ copyMap := make(map[string]float64, len(audioCompletionRatioMap))
+ for k, v := range audioCompletionRatioMap {
+ copyMap[k] = v
+ }
+ return copyMap
+}
+
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go
new file mode 100644
index 00000000..c41b930a
--- /dev/null
+++ b/setting/system_setting/fetch_setting.go
@@ -0,0 +1,34 @@
+package system_setting
+
+import "one-api/setting/config"
+
+type FetchSetting struct {
+ EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护
+ AllowPrivateIp bool `json:"allow_private_ip"`
+ DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式
+ IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式
+ DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com
+ IpList []string `json:"ip_list"` // CIDR format
+ AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000
+ ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性)
+}
+
+var defaultFetchSetting = FetchSetting{
+ EnableSSRFProtection: true, // 默认开启SSRF防护
+ AllowPrivateIp: false,
+ DomainFilterMode: false,
+ IpFilterMode: false,
+ DomainList: []string{},
+ IpList: []string{},
+ AllowedPorts: []string{"80", "443", "8080", "8443"},
+ ApplyIPFilterForDomain: false,
+}
+
+func init() {
+ // 注册到全局配置管理器
+ config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting)
+}
+
+func GetFetchSetting() *FetchSetting {
+ return &defaultFetchSetting
+}
diff --git a/setting/system_setting/passkey.go b/setting/system_setting/passkey.go
new file mode 100644
index 00000000..a0766a67
--- /dev/null
+++ b/setting/system_setting/passkey.go
@@ -0,0 +1,49 @@
+package system_setting
+
+import (
+ "net/url"
+ "one-api/common"
+ "one-api/setting/config"
+ "strings"
+)
+
+type PasskeySettings struct {
+ Enabled bool `json:"enabled"`
+ RPDisplayName string `json:"rp_display_name"`
+ RPID string `json:"rp_id"`
+ Origins string `json:"origins"`
+ AllowInsecureOrigin bool `json:"allow_insecure_origin"`
+ UserVerification string `json:"user_verification"`
+ AttachmentPreference string `json:"attachment_preference"`
+}
+
+var defaultPasskeySettings = PasskeySettings{
+ Enabled: false,
+ RPDisplayName: common.SystemName,
+ RPID: "",
+ Origins: "",
+ AllowInsecureOrigin: false,
+ UserVerification: "preferred",
+ AttachmentPreference: "",
+}
+
+func init() {
+ config.GlobalConfig.Register("passkey", &defaultPasskeySettings)
+}
+
+func GetPasskeySettings() *PasskeySettings {
+ if defaultPasskeySettings.RPID == "" && ServerAddress != "" {
+ // 从ServerAddress提取域名作为RPID
+ // ServerAddress可能是 "https://newapi.pro" 这种格式
+ serverAddr := strings.TrimSpace(ServerAddress)
+ if parsed, err := url.Parse(serverAddr); err == nil && parsed.Host != "" {
+ defaultPasskeySettings.RPID = parsed.Host
+ } else {
+ defaultPasskeySettings.RPID = serverAddr
+ }
+ }
+ if defaultPasskeySettings.Origins == "" || defaultPasskeySettings.Origins == "[]" {
+ defaultPasskeySettings.Origins = ServerAddress
+ }
+ return &defaultPasskeySettings
+}
diff --git a/setting/system_setting.go b/setting/system_setting/system_setting_old.go
similarity index 89%
rename from setting/system_setting.go
rename to setting/system_setting/system_setting_old.go
index c37a6123..4e0f1a50 100644
--- a/setting/system_setting.go
+++ b/setting/system_setting/system_setting_old.go
@@ -1,4 +1,4 @@
-package setting
+package system_setting
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""
diff --git a/types/error.go b/types/error.go
index f653e9a2..a42e8438 100644
--- a/types/error.go
+++ b/types/error.go
@@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string {
return string(e.errorCode)
}
errStr := e.Err.Error()
+ if e.errorCode == ErrorCodeCountTokenFailed {
+ return errStr
+ }
return common.MaskSensitiveInfo(errStr)
}
@@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
Code: e.errorCode,
}
}
-
- result.Message = common.MaskSensitiveInfo(result.Message)
+ if e.errorCode != ErrorCodeCountTokenFailed {
+ result.Message = common.MaskSensitiveInfo(result.Message)
+ }
return result
}
@@ -178,13 +182,23 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
Type: string(e.errorType),
}
}
- result.Message = common.MaskSensitiveInfo(result.Message)
+ if e.errorCode != ErrorCodeCountTokenFailed {
+ result.Message = common.MaskSensitiveInfo(result.Message)
+ }
return result
}
type NewAPIErrorOptions func(*NewAPIError)
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
+ var newErr *NewAPIError
+ // 保留深层传递的 new err
+ if errors.As(err, &newErr) {
+ for _, op := range ops {
+ op(newErr)
+ }
+ return newErr
+ }
e := &NewAPIError{
Err: err,
RelayError: nil,
@@ -199,8 +213,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
}
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
- if errorCode == ErrorCodeDoRequestFailed {
- err = errors.New("upstream error: do request failed")
+ var newErr *NewAPIError
+ // 保留深层传递的 new err
+ if errors.As(err, &newErr) {
+ if newErr.RelayError == nil {
+ openaiError := OpenAIError{
+ Message: newErr.Error(),
+ Type: string(errorCode),
+ Code: errorCode,
+ }
+ newErr.RelayError = openaiError
+ }
+ for _, op := range ops {
+ op(newErr)
+ }
+ return newErr
}
openaiError := OpenAIError{
Message: err.Error(),
@@ -305,6 +332,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
}
}
+func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
+ return func(e *NewAPIError) {
+ if common.DebugEnabled {
+ fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
+ }
+ e.Err = errors.New(replaceStr)
+ }
+}
+
func IsRecordErrorLog(e *NewAPIError) bool {
if e == nil {
return false
diff --git a/types/price_data.go b/types/price_data.go
index f6a92d7e..ec7fcdfe 100644
--- a/types/price_data.go
+++ b/types/price_data.go
@@ -15,6 +15,8 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
+ AudioRatio float64
+ AudioCompletionRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
@@ -27,5 +29,5 @@ type PerCallPriceData struct {
}
func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
}
diff --git a/web/index.html b/web/index.html
index 09d87ae1..df6b0e39 100644
--- a/web/index.html
+++ b/web/index.html
@@ -10,6 +10,7 @@
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
/>
New API
+
diff --git a/web/jsconfig.json b/web/jsconfig.json
new file mode 100644
index 00000000..ced4d054
--- /dev/null
+++ b/web/jsconfig.json
@@ -0,0 +1,9 @@
+{
+ "compilerOptions": {
+ "baseUrl": "./",
+ "paths": {
+ "@/*": ["src/*"]
+ }
+ },
+ "include": ["src/**/*"]
+}
\ No newline at end of file
diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx
index 32087ab0..828e7178 100644
--- a/web/src/components/auth/LoginForm.jsx
+++ b/web/src/components/auth/LoginForm.jsx
@@ -32,6 +32,9 @@ import {
onGitHubOAuthClicked,
onOIDCClicked,
onLinuxDOOAuthClicked,
+ prepareCredentialRequestOptions,
+ buildAssertionResult,
+ isPasskeySupported,
} from '../../helpers';
import Turnstile from 'react-turnstile';
import { Button, Card, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui';
@@ -39,7 +42,7 @@ import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login';
-import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons';
+import { IconGithubLogo, IconMail, IconLock, IconKey } from '@douyinfe/semi-icons';
import OIDCIcon from '../common/logo/OIDCIcon';
import WeChatIcon from '../common/logo/WeChatIcon';
import LinuxDoIcon from '../common/logo/LinuxDoIcon';
@@ -74,6 +77,8 @@ const LoginForm = () => {
useState(false);
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
const [showTwoFA, setShowTwoFA] = useState(false);
+ const [passkeySupported, setPasskeySupported] = useState(false);
+ const [passkeyLoading, setPasskeyLoading] = useState(false);
const logo = getLogo();
const systemName = getSystemName();
@@ -95,6 +100,12 @@ const LoginForm = () => {
}
}, [status]);
+ useEffect(() => {
+ isPasskeySupported()
+ .then(setPasskeySupported)
+ .catch(() => setPasskeySupported(false));
+ }, []);
+
useEffect(() => {
if (searchParams.get('expired')) {
showError(t('未登录或登录已过期,请重新登录'));
@@ -266,6 +277,55 @@ const LoginForm = () => {
setEmailLoginLoading(false);
};
+ const handlePasskeyLogin = async () => {
+ if (!passkeySupported) {
+ showInfo('当前环境无法使用 Passkey 登录');
+ return;
+ }
+ if (!window.PublicKeyCredential) {
+ showInfo('当前浏览器不支持 Passkey');
+ return;
+ }
+
+ setPasskeyLoading(true);
+ try {
+ const beginRes = await API.post('/api/user/passkey/login/begin');
+ const { success, message, data } = beginRes.data;
+ if (!success) {
+ showError(message || '无法发起 Passkey 登录');
+ return;
+ }
+
+ const publicKeyOptions = prepareCredentialRequestOptions(data?.options || data?.publicKey || data);
+ const assertion = await navigator.credentials.get({ publicKey: publicKeyOptions });
+ const payload = buildAssertionResult(assertion);
+ if (!payload) {
+ showError('Passkey 验证失败,请重试');
+ return;
+ }
+
+ const finishRes = await API.post('/api/user/passkey/login/finish', payload);
+ const finish = finishRes.data;
+ if (finish.success) {
+ userDispatch({ type: 'login', payload: finish.data });
+ setUserData(finish.data);
+ updateAPI();
+ showSuccess('登录成功!');
+ navigate('/console');
+ } else {
+ showError(finish.message || 'Passkey 登录失败,请重试');
+ }
+ } catch (error) {
+ if (error?.name === 'AbortError') {
+ showInfo('已取消 Passkey 登录');
+ } else {
+ showError('Passkey 登录失败,请重试');
+ }
+ } finally {
+ setPasskeyLoading(false);
+ }
+ };
+
// 包装的重置密码点击处理
const handleResetPasswordClick = () => {
setResetPasswordLoading(true);
@@ -385,6 +445,19 @@ const LoginForm = () => {
)}
+ {status.passkey_login && passkeySupported && (
+
}
+ onClick={handlePasskeyLogin}
+ loading={passkeyLoading}
+ >
+