diff --git a/.env.example b/.env.example
index ea9061fb..f4b9d02e 100644
--- a/.env.example
+++ b/.env.example
@@ -57,6 +57,9 @@
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=300
+# TLS / HTTP 跳过验证设置
+# TLS_INSECURE_SKIP_VERIFY=false
+
# Gemini 识别图片 最大图片数量
# GEMINI_VISION_MAX_IMAGE_NUM=16
diff --git a/common/api_type.go b/common/api_type.go
index 4f5c1826..39c1fe9a 100644
--- a/common/api_type.go
+++ b/common/api_type.go
@@ -73,6 +73,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeMiniMax
case constant.ChannelTypeReplicate:
apiType = constant.APITypeReplicate
+ case constant.ChannelTypeCodex:
+ apiType = constant.APITypeCodex
}
if apiType == -1 {
return constant.APITypeOpenAI, false
diff --git a/common/constants.go b/common/constants.go
index e33a64b2..51b798db 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -1,6 +1,7 @@
package common
import (
+ "crypto/tls"
//"os"
//"strconv"
"sync"
@@ -73,6 +74,9 @@ var MemoryCacheEnabled bool
var LogConsumeEnabled = true
+var TLSInsecureSkipVerify bool
+var InsecureTLSConfig = &tls.Config{InsecureSkipVerify: true}
+
var SMTPServer = ""
var SMTPPort = 587
var SMTPSSLEnabled = false
diff --git a/common/init.go b/common/init.go
index 0789f8cc..9501ce3b 100644
--- a/common/init.go
+++ b/common/init.go
@@ -4,6 +4,7 @@ import (
"flag"
"fmt"
"log"
+ "net/http"
"os"
"path/filepath"
"strconv"
@@ -81,6 +82,16 @@ func InitEnv() {
DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
+ TLSInsecureSkipVerify = GetEnvOrDefaultBool("TLS_INSECURE_SKIP_VERIFY", false)
+ if TLSInsecureSkipVerify {
+ if tr, ok := http.DefaultTransport.(*http.Transport); ok && tr != nil {
+ if tr.TLSClientConfig != nil {
+ tr.TLSClientConfig.InsecureSkipVerify = true
+ } else {
+ tr.TLSClientConfig = InsecureTLSConfig
+ }
+ }
+ }
// Parse requestInterval and set RequestInterval
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
diff --git a/common/utils.go b/common/utils.go
index f63df857..b67fe1c5 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -263,7 +263,7 @@ func GetTimestamp() int64 {
}
func GetTimeString() string {
- now := time.Now()
+ now := time.Now().UTC()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
diff --git a/constant/api_type.go b/constant/api_type.go
index 32b48bcd..536ebd2c 100644
--- a/constant/api_type.go
+++ b/constant/api_type.go
@@ -35,5 +35,6 @@ const (
APITypeSubmodel
APITypeMiniMax
APITypeReplicate
+ APITypeCodex
APITypeDummy // this one is only for count, do not add any channel after this
)
diff --git a/constant/channel.go b/constant/channel.go
index 6d3a5d92..48502bed 100644
--- a/constant/channel.go
+++ b/constant/channel.go
@@ -54,6 +54,7 @@ const (
ChannelTypeDoubaoVideo = 54
ChannelTypeSora = 55
ChannelTypeReplicate = 56
+ ChannelTypeCodex = 57
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -116,6 +117,7 @@ var ChannelBaseURLs = []string{
"https://ark.cn-beijing.volces.com", //54
"https://api.openai.com", //55
"https://api.replicate.com", //56
+ "https://chatgpt.com", //57
}
var ChannelTypeNames = map[int]string{
@@ -172,6 +174,7 @@ var ChannelTypeNames = map[int]string{
ChannelTypeDoubaoVideo: "DoubaoVideo",
ChannelTypeSora: "Sora",
ChannelTypeReplicate: "Replicate",
+ ChannelTypeCodex: "Codex",
}
func GetChannelTypeName(channelType int) string {
diff --git a/controller/channel-test.go b/controller/channel-test.go
index f9657edb..8ebfbdf6 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -193,6 +193,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
+ info.IsChannelTest = true
info.InitChannelMeta(c)
err = helper.ModelMappedHelper(c, info, request)
@@ -309,8 +310,29 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
}
+
+ //jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ //if err != nil {
+ // return testResult{
+ // context: c,
+ // localErr: err,
+ // newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
+ // }
+ //}
+
+ if len(info.ParamOverride) > 0 {
+ jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
+ }
+ }
+ }
+
requestBody := bytes.NewBuffer(jsonData)
- c.Request.Body = io.NopCloser(requestBody)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
return testResult{
diff --git a/controller/channel.go b/controller/channel.go
index 9fea9a80..3ac29d7c 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -1,16 +1,19 @@
package controller
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
+ "time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
+ "github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/ollama"
"github.com/QuantumNous/new-api/service"
@@ -260,11 +263,37 @@ func FetchUpstreamModels(c *gin.Context) {
return
}
+ // 对于 Gemini 渠道,使用特殊处理
+ if channel.Type == constant.ChannelTypeGemini {
+ // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
+ key, _, apiErr := channel.GetNextEnabledKey()
+ if apiErr != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
+ })
+ return
+ }
+ key = strings.TrimSpace(key)
+ models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": models,
+ })
+ return
+ }
+
var url string
switch channel.Type {
- case constant.ChannelTypeGemini:
- // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
- url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
case constant.ChannelTypeZhipu_v4:
@@ -577,9 +606,60 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
}
}
+ // Codex OAuth key validation (optional, only when JSON object is provided)
+ if channel.Type == constant.ChannelTypeCodex {
+ trimmedKey := strings.TrimSpace(channel.Key)
+ if isAdd || trimmedKey != "" {
+ if !strings.HasPrefix(trimmedKey, "{") {
+ return fmt.Errorf("Codex key must be a valid JSON object")
+ }
+ var keyMap map[string]any
+ if err := common.Unmarshal([]byte(trimmedKey), &keyMap); err != nil {
+ return fmt.Errorf("Codex key must be a valid JSON object")
+ }
+ if v, ok := keyMap["access_token"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
+ return fmt.Errorf("Codex key JSON must include access_token")
+ }
+ if v, ok := keyMap["account_id"]; !ok || v == nil || strings.TrimSpace(fmt.Sprintf("%v", v)) == "" {
+ return fmt.Errorf("Codex key JSON must include account_id")
+ }
+ }
+ }
+
return nil
}
+func RefreshCodexChannelCredential(c *gin.Context) {
+ channelId, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
+ defer cancel()
+
+ oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "refreshed",
+ "data": gin.H{
+ "expires_at": oauthKey.Expired,
+ "last_refresh": oauthKey.LastRefresh,
+ "account_id": oauthKey.AccountID,
+ "email": oauthKey.Email,
+ "channel_id": ch.Id,
+ "channel_type": ch.Type,
+ "channel_name": ch.Name,
+ },
+ })
+}
+
type AddChannelRequest struct {
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
@@ -1072,6 +1152,23 @@ func FetchModels(c *gin.Context) {
return
}
+ if req.Type == constant.ChannelTypeGemini {
+ models, err := gemini.FetchGeminiModels(baseURL, key, "")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": models,
+ })
+ return
+ }
+
client := &http.Client{}
url := fmt.Sprintf("%s/v1/models", baseURL)
diff --git a/controller/codex_oauth.go b/controller/codex_oauth.go
new file mode 100644
index 00000000..3c881ebb
--- /dev/null
+++ b/controller/codex_oauth.go
@@ -0,0 +1,243 @@
+package controller
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/model"
+ "github.com/QuantumNous/new-api/relay/channel/codex"
+ "github.com/QuantumNous/new-api/service"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type codexOAuthCompleteRequest struct {
+ Input string `json:"input"`
+}
+
+func codexOAuthSessionKey(channelID int, field string) string {
+ return fmt.Sprintf("codex_oauth_%s_%d", field, channelID)
+}
+
+func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
+ v := strings.TrimSpace(input)
+ if v == "" {
+ return "", "", errors.New("empty input")
+ }
+ if strings.Contains(v, "#") {
+ parts := strings.SplitN(v, "#", 2)
+ code = strings.TrimSpace(parts[0])
+ state = strings.TrimSpace(parts[1])
+ return code, state, nil
+ }
+ if strings.Contains(v, "code=") {
+ u, parseErr := url.Parse(v)
+ if parseErr == nil {
+ q := u.Query()
+ code = strings.TrimSpace(q.Get("code"))
+ state = strings.TrimSpace(q.Get("state"))
+ return code, state, nil
+ }
+ q, parseErr := url.ParseQuery(v)
+ if parseErr == nil {
+ code = strings.TrimSpace(q.Get("code"))
+ state = strings.TrimSpace(q.Get("state"))
+ return code, state, nil
+ }
+ }
+
+ code = v
+ return code, "", nil
+}
+
+func StartCodexOAuth(c *gin.Context) {
+ startCodexOAuthWithChannelID(c, 0)
+}
+
+func StartCodexOAuthForChannel(c *gin.Context) {
+ channelID, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
+ return
+ }
+ startCodexOAuthWithChannelID(c, channelID)
+}
+
+func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
+ if channelID > 0 {
+ ch, err := model.GetChannelById(channelID, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if ch == nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
+ return
+ }
+ if ch.Type != constant.ChannelTypeCodex {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
+ return
+ }
+ }
+
+ flow, err := service.CreateCodexOAuthAuthorizationFlow()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ session := sessions.Default(c)
+ session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
+ session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
+ session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
+ _ = session.Save()
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "authorize_url": flow.AuthorizeURL,
+ },
+ })
+}
+
+func CompleteCodexOAuth(c *gin.Context) {
+ completeCodexOAuthWithChannelID(c, 0)
+}
+
+func CompleteCodexOAuthForChannel(c *gin.Context) {
+ channelID, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
+ return
+ }
+ completeCodexOAuthWithChannelID(c, channelID)
+}
+
+func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
+ req := codexOAuthCompleteRequest{}
+ if err := c.ShouldBindJSON(&req); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ code, state, err := parseCodexAuthorizationInput(req.Input)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ if strings.TrimSpace(code) == "" {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
+ return
+ }
+ if strings.TrimSpace(state) == "" {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
+ return
+ }
+
+ if channelID > 0 {
+ ch, err := model.GetChannelById(channelID, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if ch == nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
+ return
+ }
+ if ch.Type != constant.ChannelTypeCodex {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
+ return
+ }
+ }
+
+ session := sessions.Default(c)
+ expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
+ verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
+ if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
+ return
+ }
+ if state != expectedState {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
+ defer cancel()
+
+ tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
+ if !ok {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
+ return
+ }
+ email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
+
+ key := codex.OAuthKey{
+ AccessToken: tokenRes.AccessToken,
+ RefreshToken: tokenRes.RefreshToken,
+ AccountID: accountID,
+ LastRefresh: time.Now().Format(time.RFC3339),
+ Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
+ Email: email,
+ Type: "codex",
+ }
+ encoded, err := common.Marshal(key)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ session.Delete(codexOAuthSessionKey(channelID, "state"))
+ session.Delete(codexOAuthSessionKey(channelID, "verifier"))
+ session.Delete(codexOAuthSessionKey(channelID, "created_at"))
+ _ = session.Save()
+
+ if channelID > 0 {
+ if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ service.ResetProxyClientCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "saved",
+ "data": gin.H{
+ "channel_id": channelID,
+ "account_id": accountID,
+ "email": email,
+ "expires_at": key.Expired,
+ "last_refresh": key.LastRefresh,
+ },
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "generated",
+ "data": gin.H{
+ "key": string(encoded),
+ "account_id": accountID,
+ "email": email,
+ "expires_at": key.Expired,
+ "last_refresh": key.LastRefresh,
+ },
+ })
+}
diff --git a/controller/codex_usage.go b/controller/codex_usage.go
new file mode 100644
index 00000000..61614b46
--- /dev/null
+++ b/controller/codex_usage.go
@@ -0,0 +1,124 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/model"
+ "github.com/QuantumNous/new-api/relay/channel/codex"
+ "github.com/QuantumNous/new-api/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetCodexChannelUsage(c *gin.Context) {
+ channelId, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
+ return
+ }
+
+ ch, err := model.GetChannelById(channelId, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if ch == nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
+ return
+ }
+ if ch.Type != constant.ChannelTypeCodex {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
+ return
+ }
+ if ch.ChannelInfo.IsMultiKey {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
+ return
+ }
+
+ oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ accessToken := strings.TrimSpace(oauthKey.AccessToken)
+ accountID := strings.TrimSpace(oauthKey.AccountID)
+ if accessToken == "" {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
+ return
+ }
+ if accountID == "" {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
+ return
+ }
+
+ client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
+ defer cancel()
+
+ statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
+ refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
+ defer refreshCancel()
+
+ res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+ if refreshErr == nil {
+ oauthKey.AccessToken = res.AccessToken
+ oauthKey.RefreshToken = res.RefreshToken
+ oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
+ oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
+ if strings.TrimSpace(oauthKey.Type) == "" {
+ oauthKey.Type = "codex"
+ }
+
+ encoded, encErr := common.Marshal(oauthKey)
+ if encErr == nil {
+ _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
+ model.InitChannelCache()
+ service.ResetProxyClientCache()
+ }
+
+ ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
+ defer cancel2()
+ statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ }
+ }
+
+ var payload any
+ if json.Unmarshal(body, &payload) != nil {
+ payload = string(body)
+ }
+
+ ok := statusCode >= 200 && statusCode < 300
+ resp := gin.H{
+ "success": ok,
+ "message": "",
+ "upstream_status": statusCode,
+ "data": payload,
+ }
+ if !ok {
+ resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
+ }
+ c.JSON(http.StatusOK, resp)
+}
diff --git a/controller/model_sync.go b/controller/model_sync.go
index b2ac99da..737f92d4 100644
--- a/controller/model_sync.go
+++ b/controller/model_sync.go
@@ -99,6 +99,9 @@ func newHTTPClient() *http.Client {
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
}
+ if common.TLSInsecureSkipVerify {
+ transport.TLSClientConfig = common.InsecureTLSConfig
+ }
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
@@ -115,7 +118,17 @@ func newHTTPClient() *http.Client {
return &http.Client{Transport: transport}
}
-var httpClient = newHTTPClient()
+var (
+ httpClientOnce sync.Once
+ httpClient *http.Client
+)
+
+func getHTTPClient() *http.Client {
+ httpClientOnce.Do(func() {
+ httpClient = newHTTPClient()
+ })
+ return httpClient
+}
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
var lastErr error
@@ -138,7 +151,7 @@ func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T])
}
cacheMutex.RUnlock()
- resp, err := httpClient.Do(req)
+ resp, err := getHTTPClient().Do(req)
if err != nil {
lastErr = err
// backoff with jitter
diff --git a/controller/option.go b/controller/option.go
index 4d5b4e8d..959f2f9b 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -10,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/console_setting"
+ "github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/setting/system_setting"
@@ -177,6 +178,24 @@ func UpdateOption(c *gin.Context) {
})
return
}
+ case "AutomaticDisableStatusCodes":
+ _, err = operation_setting.ParseHTTPStatusCodeRanges(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "AutomaticRetryStatusCodes":
+ _, err = operation_setting.ParseHTTPStatusCodeRanges(option.Value.(string))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
case "console_setting.api_info":
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
if err != nil {
diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go
index b8224b81..0b6a6dff 100644
--- a/controller/ratio_sync.go
+++ b/controller/ratio_sync.go
@@ -11,6 +11,7 @@ import (
"sync"
"time"
+ "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/dto"
@@ -110,6 +111,9 @@ func FetchUpstreamRatios(c *gin.Context) {
dialer := &net.Dialer{Timeout: 10 * time.Second}
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
+ if common.TLSInsecureSkipVerify {
+ transport.TLSClientConfig = common.InsecureTLSConfig
+ }
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
diff --git a/controller/relay.go b/controller/relay.go
index 9759fa30..4fba947f 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -21,6 +21,7 @@ import (
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting"
+ "github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
"github.com/bytedance/gopkg/util/gopool"
@@ -316,30 +317,14 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
- if openaiErr.StatusCode == http.StatusTooManyRequests {
- return true
- }
- if openaiErr.StatusCode == 307 {
- return true
- }
- if openaiErr.StatusCode/100 == 5 {
- // 超时不重试
- if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
- return false
- }
- return true
- }
- if openaiErr.StatusCode == http.StatusBadRequest {
+ code := openaiErr.StatusCode
+ if code >= 200 && code < 300 {
return false
}
- if openaiErr.StatusCode == 408 {
- // azure处理超时不重试
- return false
+ if code < 100 || code > 599 {
+ return true
}
- if openaiErr.StatusCode/100 == 2 {
- return false
- }
- return true
+ return operation_setting.ShouldRetryByStatusCode(code)
}
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
@@ -348,7 +333,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
gopool.Go(func() {
- service.DisableChannel(channelError, err.Error())
+ service.DisableChannel(channelError, err.ErrorWithStatusCode())
})
}
@@ -378,7 +363,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
- model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
+ model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other)
}
}
diff --git a/dto/error.go b/dto/error.go
index 78197765..be57407f 100644
--- a/dto/error.go
+++ b/dto/error.go
@@ -26,7 +26,8 @@ type GeneralErrorResponse struct {
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
- Metadata json.RawMessage `json:"metadata,omitempty"`
+ Metadata json.RawMessage `json:"metadata,omitempty"`
+ Detail string `json:"detail,omitempty"`
Header struct {
Message string `json:"message"`
} `json:"header"`
@@ -79,6 +80,9 @@ func (e GeneralErrorResponse) ToMessage() string {
if e.ErrorMsg != "" {
return e.ErrorMsg
}
+ if e.Detail != "" {
+ return e.Detail
+ }
if e.Header.Message != "" {
return e.Header.Message
}
diff --git a/dto/gemini.go b/dto/gemini.go
index 7c5969ef..17881c52 100644
--- a/dto/gemini.go
+++ b/dto/gemini.go
@@ -341,6 +341,88 @@ type GeminiChatGenerationConfig struct {
ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
}
+// UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields.
+func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
+ type Alias GeminiChatGenerationConfig
+ var aux struct {
+ Alias
+ TopPSnake float64 `json:"top_p,omitempty"`
+ TopKSnake float64 `json:"top_k,omitempty"`
+ MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
+ CandidateCountSnake int `json:"candidate_count,omitempty"`
+ StopSequencesSnake []string `json:"stop_sequences,omitempty"`
+ ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
+ ResponseSchemaSnake any `json:"response_schema,omitempty"`
+ ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
+ PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
+ FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
+ ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
+ MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
+ ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
+ ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
+ SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
+ ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
+ }
+
+ if err := common.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ *c = GeminiChatGenerationConfig(aux.Alias)
+
+ // Prioritize snake_case if present
+ if aux.TopPSnake != 0 {
+ c.TopP = aux.TopPSnake
+ }
+ if aux.TopKSnake != 0 {
+ c.TopK = aux.TopKSnake
+ }
+ if aux.MaxOutputTokensSnake != 0 {
+ c.MaxOutputTokens = aux.MaxOutputTokensSnake
+ }
+ if aux.CandidateCountSnake != 0 {
+ c.CandidateCount = aux.CandidateCountSnake
+ }
+ if len(aux.StopSequencesSnake) > 0 {
+ c.StopSequences = aux.StopSequencesSnake
+ }
+ if aux.ResponseMimeTypeSnake != "" {
+ c.ResponseMimeType = aux.ResponseMimeTypeSnake
+ }
+ if aux.ResponseSchemaSnake != nil {
+ c.ResponseSchema = aux.ResponseSchemaSnake
+ }
+ if len(aux.ResponseJsonSchemaSnake) > 0 {
+ c.ResponseJsonSchema = aux.ResponseJsonSchemaSnake
+ }
+ if aux.PresencePenaltySnake != nil {
+ c.PresencePenalty = aux.PresencePenaltySnake
+ }
+ if aux.FrequencyPenaltySnake != nil {
+ c.FrequencyPenalty = aux.FrequencyPenaltySnake
+ }
+ if aux.ResponseLogprobsSnake {
+ c.ResponseLogprobs = aux.ResponseLogprobsSnake
+ }
+ if aux.MediaResolutionSnake != "" {
+ c.MediaResolution = aux.MediaResolutionSnake
+ }
+ if len(aux.ResponseModalitiesSnake) > 0 {
+ c.ResponseModalities = aux.ResponseModalitiesSnake
+ }
+ if aux.ThinkingConfigSnake != nil {
+ c.ThinkingConfig = aux.ThinkingConfigSnake
+ }
+ if len(aux.SpeechConfigSnake) > 0 {
+ c.SpeechConfig = aux.SpeechConfigSnake
+ }
+ if len(aux.ImageConfigSnake) > 0 {
+ c.ImageConfig = aux.ImageConfigSnake
+ }
+
+ return nil
+}
+
type MediaResolution string
type GeminiChatCandidate struct {
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 232a1ae1..89ebcf14 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -808,11 +808,11 @@ type OpenAIResponsesRequest struct {
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
Stream bool `json:"stream,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
- TopP float64 `json:"top_p,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
Truncation string `json:"truncation,omitempty"`
User string `json:"user,omitempty"`
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
diff --git a/dto/openai_response.go b/dto/openai_response.go
index 6baee78c..19ca9290 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -334,13 +334,16 @@ type IncompleteDetails struct {
}
type ResponsesOutput struct {
- Type string `json:"type"`
- ID string `json:"id"`
- Status string `json:"status"`
- Role string `json:"role"`
- Content []ResponsesOutputContent `json:"content"`
- Quality string `json:"quality"`
- Size string `json:"size"`
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Status string `json:"status"`
+ Role string `json:"role"`
+ Content []ResponsesOutputContent `json:"content"`
+ Quality string `json:"quality"`
+ Size string `json:"size"`
+ CallId string `json:"call_id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Arguments string `json:"arguments,omitempty"`
}
type ResponsesOutputContent struct {
@@ -369,6 +372,10 @@ type ResponsesStreamResponse struct {
Response *OpenAIResponsesResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"`
+ // - response.function_call_arguments.delta
+ // - response.function_call_arguments.done
+ OutputIndex *int `json:"output_index,omitempty"`
+ ItemID string `json:"item_id,omitempty"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
diff --git a/dto/values.go b/dto/values.go
new file mode 100644
index 00000000..860d5fae
--- /dev/null
+++ b/dto/values.go
@@ -0,0 +1,55 @@
+package dto
+
+import (
+ "encoding/json"
+ "strconv"
+)
+
+type IntValue int
+
+func (i *IntValue) UnmarshalJSON(b []byte) error {
+ var n int
+ if err := json.Unmarshal(b, &n); err == nil {
+ *i = IntValue(n)
+ return nil
+ }
+ var s string
+ if err := json.Unmarshal(b, &s); err != nil {
+ return err
+ }
+ v, err := strconv.Atoi(s)
+ if err != nil {
+ return err
+ }
+ *i = IntValue(v)
+ return nil
+}
+
+func (i IntValue) MarshalJSON() ([]byte, error) {
+ return json.Marshal(int(i))
+}
+
+type BoolValue bool
+
+func (b *BoolValue) UnmarshalJSON(data []byte) error {
+ var boolean bool
+ if err := json.Unmarshal(data, &boolean); err == nil {
+ *b = BoolValue(boolean)
+ return nil
+ }
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ if str == "true" {
+ *b = BoolValue(true)
+ } else if str == "false" {
+ *b = BoolValue(false)
+ } else {
+ return json.Unmarshal(data, &boolean)
+ }
+ return nil
+}
+func (b BoolValue) MarshalJSON() ([]byte, error) {
+ return json.Marshal(bool(b))
+}
diff --git a/main.go b/main.go
index 4c0fc8c6..1326b122 100644
--- a/main.go
+++ b/main.go
@@ -102,6 +102,9 @@ func main() {
go controller.AutomaticallyTestChannels()
+ // Codex credential auto-refresh check every 10 minutes, refresh when expires within 1 day
+ service.StartCodexCredentialAutoRefreshTask()
+
if common.IsMasterNode && constant.UpdateTask {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
diff --git a/middleware/auth.go b/middleware/auth.go
index 85c46e28..a5d283d2 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -13,6 +13,7 @@ import (
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
+ "github.com/QuantumNous/new-api/types"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -195,7 +196,7 @@ func TokenAuth() func(c *gin.Context) {
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
- // 检查path包含/v1/messages 或 /v1/models
+ // 检查path包含/v1/messages 或 /v1/models
if strings.Contains(c.Request.URL.Path, "/v1/messages") || strings.Contains(c.Request.URL.Path, "/v1/models") {
anthropicKey := c.Request.Header.Get("x-api-key")
if anthropicKey != "" {
@@ -256,7 +257,7 @@ func TokenAuth() func(c *gin.Context) {
return
}
if common.IsIpInCIDRList(ip, allowIps) == false {
- abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+ abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中", types.ErrorCodeAccessDenied)
return
}
logger.LogDebug(c, "Client IP %s passed the token IP restrictions check", clientIp)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index a3340472..95fa64a3 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -114,11 +114,11 @@ func Distribute() func(c *gin.Context) {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员"
//}
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, types.ErrorCodeModelNotFound)
return
}
if channel == nil {
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", usingGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", usingGroup, modelRequest.Model), types.ErrorCodeModelNotFound)
return
}
}
diff --git a/middleware/utils.go b/middleware/utils.go
index 24caa83c..f198af81 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -5,13 +5,14 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
+ "github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
-func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...types.ErrorCode) {
codeStr := ""
if len(code) > 0 {
- codeStr = code[0]
+ codeStr = string(code[0])
}
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{
diff --git a/model/log.go b/model/log.go
index 7495d647..f8940c15 100644
--- a/model/log.go
+++ b/model/log.go
@@ -56,8 +56,9 @@ func formatUserLogs(logs []*Log) {
var otherMap map[string]interface{}
otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil {
- // delete admin
+ // Remove admin-only debug fields.
delete(otherMap, "admin_info")
+ delete(otherMap, "request_conversion")
}
logs[i].Other = common.MapToJsonStr(otherMap)
logs[i].Id = logs[i].Id % 1024
diff --git a/model/option.go b/model/option.go
index e9fd50d7..e268cf57 100644
--- a/model/option.go
+++ b/model/option.go
@@ -143,6 +143,8 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+ common.OptionMap["AutomaticDisableStatusCodes"] = operation_setting.AutomaticDisableStatusCodesToString()
+ common.OptionMap["AutomaticRetryStatusCodes"] = operation_setting.AutomaticRetryStatusCodesToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
@@ -444,6 +446,10 @@ func updateOptionMap(key string, value string) (err error) {
setting.SensitiveWordsFromString(value)
case "AutomaticDisableKeywords":
operation_setting.AutomaticDisableKeywordsFromString(value)
+ case "AutomaticDisableStatusCodes":
+ err = operation_setting.AutomaticDisableStatusCodesFromString(value)
+ case "AutomaticRetryStatusCodes":
+ err = operation_setting.AutomaticRetryStatusCodesFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 751a4538..23ef5f4b 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -13,6 +13,8 @@ import (
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
+ "github.com/QuantumNous/new-api/setting/model_setting"
+ "github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -22,6 +24,18 @@ type Adaptor struct {
IsSyncImageModel bool
}
+/*
+ var syncModels = []string{
+ "z-image",
+ "qwen-image",
+ "wan2.6",
+ }
+*/
+func supportsAliAnthropicMessages(modelName string) bool {
+ // Only models with the "qwen" designation can use the Claude-compatible interface; others require conversion.
+ return strings.Contains(strings.ToLower(modelName), "qwen")
+}
+
var syncModels = []string{
"z-image",
"qwen-image",
@@ -29,12 +43,7 @@ var syncModels = []string{
}
func isSyncImageModel(modelName string) bool {
- for _, m := range syncModels {
- if strings.Contains(modelName, m) {
- return true
- }
- }
- return false
+ return model_setting.IsSyncImageModel(modelName)
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -43,7 +52,18 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- return req, nil
+ if supportsAliAnthropicMessages(info.UpstreamModelName) {
+ return req, nil
+ }
+
+ oaiReq, err := service.ClaudeToOpenAIRequest(*req, info)
+ if err != nil {
+ return nil, err
+ }
+ if info.SupportStreamOptions && info.IsStream {
+ oaiReq.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
+ }
+ return a.ConvertOpenAIRequest(c, info, oaiReq)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -53,7 +73,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayFormat {
case types.RelayFormatClaude:
- fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
+ if supportsAliAnthropicMessages(info.UpstreamModelName) {
+ fullRequestURL = fmt.Sprintf("%s/apps/anthropic/v1/messages", info.ChannelBaseUrl)
+ } else {
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
+ }
default:
switch info.RelayMode {
case constant.RelayModeEmbeddings:
@@ -197,11 +221,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case types.RelayFormatClaude:
- if info.IsStream {
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- } else {
+ if supportsAliAnthropicMessages(info.UpstreamModelName) {
+ if info.IsStream {
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ }
+
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
+
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
default:
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 1ff1e239..128e9453 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -71,6 +71,12 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return nil, fmt.Errorf("new request failed: %w", err)
}
headers := req.Header
+ err = a.SetupRequestHeader(c, &headers, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ // 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+ // 这样可以覆盖默认的 Authorization header 设置
headerOverride, err := processHeaderOverride(info)
if err != nil {
return nil, err
@@ -78,10 +84,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
for key, value := range headerOverride {
headers.Set(key, value)
}
- err = a.SetupRequestHeader(c, &headers, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
@@ -104,6 +106,12 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header
+ err = a.SetupRequestHeader(c, &headers, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ // 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+ // 这样可以覆盖默认的 Authorization header 设置
headerOverride, err := processHeaderOverride(info)
if err != nil {
return nil, err
@@ -111,10 +119,6 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
for key, value := range headerOverride {
headers.Set(key, value)
}
- err = a.SetupRequestHeader(c, &headers, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
@@ -128,6 +132,12 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return nil, fmt.Errorf("get request url failed: %w", err)
}
targetHeader := http.Header{}
+ err = a.SetupRequestHeader(c, &targetHeader, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ // 在 SetupRequestHeader 之后应用 Header Override,确保用户设置优先级最高
+ // 这样可以覆盖默认的 Authorization header 设置
headerOverride, err := processHeaderOverride(info)
if err != nil {
return nil, err
@@ -135,10 +145,6 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
for key, value := range headerOverride {
targetHeader.Set(key, value)
}
- err = a.SetupRequestHeader(c, &targetHeader, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
if err != nil {
diff --git a/relay/channel/codex/adaptor.go b/relay/channel/codex/adaptor.go
new file mode 100644
index 00000000..ab61dfac
--- /dev/null
+++ b/relay/channel/codex/adaptor.go
@@ -0,0 +1,164 @@
+package codex
+
+import (
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/relay/channel"
+ "github.com/QuantumNous/new-api/relay/channel/openai"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ relayconstant "github.com/QuantumNous/new-api/relay/constant"
+ "github.com/QuantumNous/new-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("codex channel: endpoint not supported")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ if info != nil && info.ChannelSetting.SystemPrompt != "" {
+ systemPrompt := info.ChannelSetting.SystemPrompt
+
+ if len(request.Instructions) == 0 {
+ if b, err := common.Marshal(systemPrompt); err == nil {
+ request.Instructions = b
+ } else {
+ return nil, err
+ }
+ } else if info.ChannelSetting.SystemPromptOverride {
+ var existing string
+ if err := common.Unmarshal(request.Instructions, &existing); err == nil {
+ existing = strings.TrimSpace(existing)
+ if existing == "" {
+ if b, err := common.Marshal(systemPrompt); err == nil {
+ request.Instructions = b
+ } else {
+ return nil, err
+ }
+ } else {
+ if b, err := common.Marshal(systemPrompt + "\n" + existing); err == nil {
+ request.Instructions = b
+ } else {
+ return nil, err
+ }
+ }
+ } else {
+ if b, err := common.Marshal(systemPrompt); err == nil {
+ request.Instructions = b
+ } else {
+ return nil, err
+ }
+ }
+ }
+ }
+
+ // codex: store must be false
+ request.Store = json.RawMessage("false")
+ // rm max_output_tokens
+ request.MaxOutputTokens = 0
+ request.Temperature = nil
+ return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode != relayconstant.RelayModeResponses {
+ return nil, types.NewError(errors.New("codex channel: endpoint not supported"), types.ErrorCodeInvalidRequest)
+ }
+
+ if info.IsStream {
+ return openai.OaiResponsesStreamHandler(c, info, resp)
+ }
+ return openai.OaiResponsesHandler(c, info, resp)
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayMode != relayconstant.RelayModeResponses {
+ return "", errors.New("codex channel: only /v1/responses is supported")
+ }
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, "/backend-api/codex/responses", info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+
+ key := strings.TrimSpace(info.ApiKey)
+ if !strings.HasPrefix(key, "{") {
+ return errors.New("codex channel: key must be a JSON object")
+ }
+
+ oauthKey, err := ParseOAuthKey(key)
+ if err != nil {
+ return err
+ }
+
+ accessToken := strings.TrimSpace(oauthKey.AccessToken)
+ accountID := strings.TrimSpace(oauthKey.AccountID)
+
+ if accessToken == "" {
+ return errors.New("codex channel: access_token is required")
+ }
+ if accountID == "" {
+ return errors.New("codex channel: account_id is required")
+ }
+
+ req.Set("Authorization", "Bearer "+accessToken)
+ req.Set("chatgpt-account-id", accountID)
+
+ if req.Get("OpenAI-Beta") == "" {
+ req.Set("OpenAI-Beta", "responses=experimental")
+ }
+ if req.Get("originator") == "" {
+ req.Set("originator", "codex_cli_rs")
+ }
+
+ return nil
+}
diff --git a/relay/channel/codex/constants.go b/relay/channel/codex/constants.go
new file mode 100644
index 00000000..461e033a
--- /dev/null
+++ b/relay/channel/codex/constants.go
@@ -0,0 +1,9 @@
+package codex
+
+var ModelList = []string{
+ "gpt-5", "gpt-5-codex", "gpt-5-codex-mini",
+ "gpt-5.1", "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini",
+ "gpt-5.2", "gpt-5.2-codex",
+}
+
+const ChannelName = "codex"
diff --git a/relay/channel/codex/oauth_key.go b/relay/channel/codex/oauth_key.go
new file mode 100644
index 00000000..bf143f81
--- /dev/null
+++ b/relay/channel/codex/oauth_key.go
@@ -0,0 +1,30 @@
+package codex
+
+import (
+ "errors"
+
+ "github.com/QuantumNous/new-api/common"
+)
+
+type OAuthKey struct {
+ IDToken string `json:"id_token,omitempty"`
+ AccessToken string `json:"access_token,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+
+ AccountID string `json:"account_id,omitempty"`
+ LastRefresh string `json:"last_refresh,omitempty"`
+ Email string `json:"email,omitempty"`
+ Type string `json:"type,omitempty"`
+ Expired string `json:"expired,omitempty"`
+}
+
+func ParseOAuthKey(raw string) (*OAuthKey, error) {
+ if raw == "" {
+ return nil, errors.New("codex channel: empty oauth key")
+ }
+ var key OAuthKey
+ if err := common.Unmarshal([]byte(raw), &key); err != nil {
+ return nil, errors.New("codex channel: invalid oauth key json")
+ }
+ return &key, nil
+}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index 4d93027f..49c4da8d 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -1,6 +1,7 @@
package gemini
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -8,6 +9,7 @@ import (
"net/http"
"strconv"
"strings"
+ "time"
"unicode/utf8"
"github.com/QuantumNous/new-api/common"
@@ -653,101 +655,84 @@ func getSupportedMimeTypesList() []string {
return keys
}
+var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
+ "anyOf": {},
+ "default": {},
+ "description": {},
+ "enum": {},
+ "example": {},
+ "format": {},
+ "items": {},
+ "maxItems": {},
+ "maxLength": {},
+ "maxProperties": {},
+ "maximum": {},
+ "minItems": {},
+ "minLength": {},
+ "minProperties": {},
+ "minimum": {},
+ "nullable": {},
+ "pattern": {},
+ "properties": {},
+ "propertyOrdering": {},
+ "required": {},
+ "title": {},
+ "type": {},
+}
+
+const geminiFunctionSchemaMaxDepth = 64
+
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
+ return cleanFunctionParametersWithDepth(params, 0)
+}
+
+func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
if params == nil {
return nil
}
+ if depth >= geminiFunctionSchemaMaxDepth {
+ return cleanFunctionParametersShallow(params)
+ }
+
switch v := params.(type) {
case map[string]interface{}:
- // Create a copy to avoid modifying the original
- cleanedMap := make(map[string]interface{})
+ // Keep only Gemini-supported OpenAPI schema subset fields (per official SDK Schema).
+ cleanedMap := make(map[string]interface{}, len(v))
for k, val := range v {
- cleanedMap[k] = val
- }
-
- // Remove unsupported root-level fields
- delete(cleanedMap, "default")
- delete(cleanedMap, "exclusiveMaximum")
- delete(cleanedMap, "exclusiveMinimum")
- delete(cleanedMap, "$schema")
- delete(cleanedMap, "additionalProperties")
-
- // Check and clean 'format' for string types
- if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
- if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
- if formatValue != "enum" && formatValue != "date-time" {
- delete(cleanedMap, "format")
- }
+ if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
+ cleanedMap[k] = val
}
}
+ normalizeGeminiSchemaTypeAndNullable(cleanedMap)
+
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
- cleanedProps[propName] = cleanFunctionParameters(propValue)
+ cleanedProps[propName] = cleanFunctionParametersWithDepth(propValue, depth+1)
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
- cleanedMap["items"] = cleanFunctionParameters(items)
+ cleanedMap["items"] = cleanFunctionParametersWithDepth(items, depth+1)
}
- // Also handle items if it's an array of schemas
- if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
- cleanedItemsArray := make([]interface{}, len(itemsArray))
- for i, item := range itemsArray {
- cleanedItemsArray[i] = cleanFunctionParameters(item)
- }
- cleanedMap["items"] = cleanedItemsArray
+ // OpenAPI tuple-style items is not supported by Gemini SDK Schema; keep first to avoid API rejection.
+ if itemsArray, ok := cleanedMap["items"].([]interface{}); ok && len(itemsArray) > 0 {
+ cleanedMap["items"] = cleanFunctionParametersWithDepth(itemsArray[0], depth+1)
}
- // Recursively clean other schema composition keywords
- for _, field := range []string{"allOf", "anyOf", "oneOf"} {
- if nested, ok := cleanedMap[field].([]interface{}); ok {
- cleanedNested := make([]interface{}, len(nested))
- for i, item := range nested {
- cleanedNested[i] = cleanFunctionParameters(item)
- }
- cleanedMap[field] = cleanedNested
- }
- }
-
- // Recursively clean patternProperties
- if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
- cleanedPatternProps := make(map[string]interface{})
- for pattern, schema := range patternProps {
- cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
- }
- cleanedMap["patternProperties"] = cleanedPatternProps
- }
-
- // Recursively clean definitions
- if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
- cleanedDefinitions := make(map[string]interface{})
- for defName, defSchema := range definitions {
- cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
- }
- cleanedMap["definitions"] = cleanedDefinitions
- }
-
- // Recursively clean $defs (newer JSON Schema draft)
- if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
- cleanedDefs := make(map[string]interface{})
- for defName, defSchema := range defs {
- cleanedDefs[defName] = cleanFunctionParameters(defSchema)
- }
- cleanedMap["$defs"] = cleanedDefs
- }
-
- // Clean conditional keywords
- for _, field := range []string{"if", "then", "else", "not"} {
- if nested, ok := cleanedMap[field]; ok {
- cleanedMap[field] = cleanFunctionParameters(nested)
+ // Recursively clean anyOf
+ if nested, ok := cleanedMap["anyOf"].([]interface{}); ok && nested != nil {
+ cleanedNested := make([]interface{}, len(nested))
+ for i, item := range nested {
+ cleanedNested[i] = cleanFunctionParametersWithDepth(item, depth+1)
}
+ cleanedMap["anyOf"] = cleanedNested
}
return cleanedMap
@@ -756,7 +741,7 @@ func cleanFunctionParameters(params interface{}) interface{} {
// Handle arrays of schemas
cleanedArray := make([]interface{}, len(v))
for i, item := range v {
- cleanedArray[i] = cleanFunctionParameters(item)
+ cleanedArray[i] = cleanFunctionParametersWithDepth(item, depth+1)
}
return cleanedArray
@@ -766,6 +751,91 @@ func cleanFunctionParameters(params interface{}) interface{} {
}
}
+func cleanFunctionParametersShallow(params interface{}) interface{} {
+ switch v := params.(type) {
+ case map[string]interface{}:
+ cleanedMap := make(map[string]interface{}, len(v))
+ for k, val := range v {
+ if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
+ cleanedMap[k] = val
+ }
+ }
+ normalizeGeminiSchemaTypeAndNullable(cleanedMap)
+ // Stop recursion and avoid retaining huge nested structures.
+ delete(cleanedMap, "properties")
+ delete(cleanedMap, "items")
+ delete(cleanedMap, "anyOf")
+ return cleanedMap
+ case []interface{}:
+ // Prefer an empty list over deep recursion on attacker-controlled inputs.
+ return []interface{}{}
+ default:
+ return params
+ }
+}
+
+func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
+ rawType, ok := schema["type"]
+ if !ok || rawType == nil {
+ return
+ }
+
+ normalize := func(t string) (string, bool) {
+ switch strings.ToLower(strings.TrimSpace(t)) {
+ case "object":
+ return "OBJECT", false
+ case "array":
+ return "ARRAY", false
+ case "string":
+ return "STRING", false
+ case "integer":
+ return "INTEGER", false
+ case "number":
+ return "NUMBER", false
+ case "boolean":
+ return "BOOLEAN", false
+ case "null":
+ return "", true
+ default:
+ return t, false
+ }
+ }
+
+ switch t := rawType.(type) {
+ case string:
+ normalized, isNull := normalize(t)
+ if isNull {
+ schema["nullable"] = true
+ delete(schema, "type")
+ return
+ }
+ schema["type"] = normalized
+ case []interface{}:
+ nullable := false
+ var chosen string
+ for _, item := range t {
+ if s, ok := item.(string); ok {
+ normalized, isNull := normalize(s)
+ if isNull {
+ nullable = true
+ continue
+ }
+ if chosen == "" {
+ chosen = normalized
+ }
+ }
+ }
+ if nullable {
+ schema["nullable"] = true
+ }
+ if chosen != "" {
+ schema["type"] = chosen
+ } else {
+ delete(schema, "type")
+ }
+ }
+}
+
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
@@ -1138,6 +1208,8 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
finishReason := constant.FinishReasonStop
+ toolCallIndexByChoice := make(map[int]map[string]int)
+ nextToolCallIndexByChoice := make(map[int]int)
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
@@ -1145,6 +1217,28 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
+ for choiceIdx := range response.Choices {
+ choiceKey := response.Choices[choiceIdx].Index
+ for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
+ tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
+ if tool.ID == "" {
+ continue
+ }
+ m := toolCallIndexByChoice[choiceKey]
+ if m == nil {
+ m = make(map[string]int)
+ toolCallIndexByChoice[choiceKey] = m
+ }
+ if idx, ok := m[tool.ID]; ok {
+ tool.SetIndex(idx)
+ continue
+ }
+ idx := nextToolCallIndexByChoice[choiceKey]
+ nextToolCallIndexByChoice[choiceKey] = idx + 1
+ m[tool.ID] = idx
+ tool.SetIndex(idx)
+ }
+ }
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
if info.SendResponseCount == 0 {
@@ -1363,3 +1457,76 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
return usage, nil
}
+
+type GeminiModelsResponse struct {
+ Models []dto.GeminiModel `json:"models"`
+ NextPageToken string `json:"nextPageToken"`
+}
+
+func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
+ client, err := service.GetHttpClientWithProxy(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
+ }
+
+ allModels := make([]string, 0)
+ nextPageToken := ""
+ maxPages := 100 // Safety limit to prevent infinite loops
+
+ for page := 0; page < maxPages; page++ {
+ url := fmt.Sprintf("%s/v1beta/models", baseURL)
+ if nextPageToken != "" {
+ url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("创建请求失败: %v", err)
+ }
+
+ request.Header.Set("x-goog-api-key", apiKey)
+
+ response, err := client.Do(request)
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("请求失败: %v", err)
+ }
+
+ if response.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(response.Body)
+ response.Body.Close()
+ cancel()
+ return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
+ }
+
+ body, err := io.ReadAll(response.Body)
+ response.Body.Close()
+ cancel()
+ if err != nil {
+ return nil, fmt.Errorf("读取响应失败: %v", err)
+ }
+
+ var modelsResponse GeminiModelsResponse
+ if err = common.Unmarshal(body, &modelsResponse); err != nil {
+ return nil, fmt.Errorf("解析响应失败: %v", err)
+ }
+
+ for _, model := range modelsResponse.Models {
+ modelNameValue, ok := model.Name.(string)
+ if !ok {
+ continue
+ }
+ modelName := strings.TrimPrefix(modelNameValue, "models/")
+ allModels = append(allModels, modelName)
+ }
+
+ nextPageToken = modelsResponse.NextPageToken
+ if nextPageToken == "" {
+ break
+ }
+ }
+
+ return allModels, nil
+}
diff --git a/relay/channel/minimax/constants.go b/relay/channel/minimax/constants.go
index df420d4b..91193c4b 100644
--- a/relay/channel/minimax/constants.go
+++ b/relay/channel/minimax/constants.go
@@ -14,6 +14,9 @@ var ModelList = []string{
"speech-02-turbo",
"speech-01-hd",
"speech-01-turbo",
+ "MiniMax-M2.1",
+ "MiniMax-M2.1-lightning",
+ "MiniMax-M2",
}
var ChannelName = "minimax"
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index f40f5da6..c031fd75 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -187,6 +187,17 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
+ // 检查 Header Override 是否已设置 Authorization,如果已设置则跳过默认设置
+ // 这样可以避免在 Header Override 应用时被覆盖(虽然 Header Override 会在之后应用,但这里作为额外保护)
+ hasAuthOverride := false
+ if len(info.HeadersOverride) > 0 {
+ for k := range info.HeadersOverride {
+ if strings.EqualFold(k, "Authorization") {
+ hasAuthOverride = true
+ break
+ }
+ }
+ }
if info.RelayMode == relayconstant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
@@ -201,10 +212,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
- header.Set("Authorization", "Bearer "+info.ApiKey)
+ if !hasAuthOverride {
+ header.Set("Authorization", "Bearer "+info.ApiKey)
+ }
}
} else {
- header.Set("Authorization", "Bearer "+info.ApiKey)
+ if !hasAuthOverride {
+ header.Set("Authorization", "Bearer "+info.ApiKey)
+ }
}
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://www.newapi.ai")
diff --git a/relay/channel/openai/chat_via_responses.go b/relay/channel/openai/chat_via_responses.go
new file mode 100644
index 00000000..83f9734c
--- /dev/null
+++ b/relay/channel/openai/chat_via_responses.go
@@ -0,0 +1,369 @@
+package openai
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/logger"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ "github.com/QuantumNous/new-api/relay/helper"
+ "github.com/QuantumNous/new-api/service"
+ "github.com/QuantumNous/new-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ if resp == nil || resp.Body == nil {
+ return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+
+ defer service.CloseResponseBodyGracefully(resp)
+
+ var responsesResp dto.OpenAIResponsesResponse
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+ }
+
+ if err := common.Unmarshal(body, &responsesResp); err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ if oaiError := responsesResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+ return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
+ }
+
+ chatId := helper.GetResponseID(c)
+ chatResp, usage, err := service.ResponsesResponseToChatCompletionsResponse(&responsesResp, chatId)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ if usage == nil || usage.TotalTokens == 0 {
+ text := service.ExtractOutputTextFromResponses(&responsesResp)
+ usage = service.ResponseText2Usage(c, text, info.UpstreamModelName, info.GetEstimatePromptTokens())
+ chatResp.Usage = *usage
+ }
+
+ chatBody, err := common.Marshal(chatResp)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
+ }
+
+ service.IOCopyBytesGracefully(c, resp, chatBody)
+ return usage, nil
+}
+
+func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ if resp == nil || resp.Body == nil {
+ return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+
+ defer service.CloseResponseBodyGracefully(resp)
+
+ responseId := helper.GetResponseID(c)
+ createAt := time.Now().Unix()
+ model := info.UpstreamModelName
+
+ var (
+ usage = &dto.Usage{}
+ outputText strings.Builder
+ usageText strings.Builder
+ sentStart bool
+ sentStop bool
+ sawToolCall bool
+ streamErr *types.NewAPIError
+ )
+
+ toolCallIndexByID := make(map[string]int)
+ toolCallNameByID := make(map[string]string)
+ toolCallArgsByID := make(map[string]string)
+ toolCallNameSent := make(map[string]bool)
+ toolCallCanonicalIDByItemID := make(map[string]string)
+
+ sendStartIfNeeded := func() bool {
+ if sentStart {
+ return true
+ }
+ if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+ streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ return false
+ }
+ sentStart = true
+ return true
+ }
+
+ sendToolCallDelta := func(callID string, name string, argsDelta string) bool {
+ if callID == "" {
+ return true
+ }
+ if outputText.Len() > 0 {
+ // Prefer streaming assistant text over tool calls to match non-stream behavior.
+ return true
+ }
+ if !sendStartIfNeeded() {
+ return false
+ }
+
+ idx, ok := toolCallIndexByID[callID]
+ if !ok {
+ idx = len(toolCallIndexByID)
+ toolCallIndexByID[callID] = idx
+ }
+ if name != "" {
+ toolCallNameByID[callID] = name
+ }
+ if toolCallNameByID[callID] != "" {
+ name = toolCallNameByID[callID]
+ }
+
+ tool := dto.ToolCallResponse{
+ ID: callID,
+ Type: "function",
+ Function: dto.FunctionResponse{
+ Arguments: argsDelta,
+ },
+ }
+ tool.SetIndex(idx)
+ if name != "" && !toolCallNameSent[callID] {
+ tool.Function.Name = name
+ toolCallNameSent[callID] = true
+ }
+
+ chunk := &dto.ChatCompletionsStreamResponse{
+ Id: responseId,
+ Object: "chat.completion.chunk",
+ Created: createAt,
+ Model: model,
+ Choices: []dto.ChatCompletionsStreamResponseChoice{
+ {
+ Index: 0,
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+ ToolCalls: []dto.ToolCallResponse{tool},
+ },
+ },
+ },
+ }
+ if err := helper.ObjectData(c, chunk); err != nil {
+ streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ return false
+ }
+ sawToolCall = true
+
+ // Include tool call data in the local builder for fallback token estimation.
+ if tool.Function.Name != "" {
+ usageText.WriteString(tool.Function.Name)
+ }
+ if argsDelta != "" {
+ usageText.WriteString(argsDelta)
+ }
+ return true
+ }
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ if streamErr != nil {
+ return false
+ }
+
+ var streamResp dto.ResponsesStreamResponse
+ if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
+ logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
+ return true
+ }
+
+ switch streamResp.Type {
+ case "response.created":
+ if streamResp.Response != nil {
+ if streamResp.Response.Model != "" {
+ model = streamResp.Response.Model
+ }
+ if streamResp.Response.CreatedAt != 0 {
+ createAt = int64(streamResp.Response.CreatedAt)
+ }
+ }
+
+ case "response.output_text.delta":
+ if !sendStartIfNeeded() {
+ return false
+ }
+
+ if streamResp.Delta != "" {
+ outputText.WriteString(streamResp.Delta)
+ usageText.WriteString(streamResp.Delta)
+ delta := streamResp.Delta
+ chunk := &dto.ChatCompletionsStreamResponse{
+ Id: responseId,
+ Object: "chat.completion.chunk",
+ Created: createAt,
+ Model: model,
+ Choices: []dto.ChatCompletionsStreamResponseChoice{
+ {
+ Index: 0,
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+ Content: &delta,
+ },
+ },
+ },
+ }
+ if err := helper.ObjectData(c, chunk); err != nil {
+ streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ return false
+ }
+ }
+
+ case "response.output_item.added", "response.output_item.done":
+ if streamResp.Item == nil {
+ break
+ }
+ if streamResp.Item.Type != "function_call" {
+ break
+ }
+
+ itemID := strings.TrimSpace(streamResp.Item.ID)
+ callID := strings.TrimSpace(streamResp.Item.CallId)
+ if callID == "" {
+ callID = itemID
+ }
+ if itemID != "" && callID != "" {
+ toolCallCanonicalIDByItemID[itemID] = callID
+ }
+ name := strings.TrimSpace(streamResp.Item.Name)
+ if name != "" {
+ toolCallNameByID[callID] = name
+ }
+
+ newArgs := streamResp.Item.Arguments
+ prevArgs := toolCallArgsByID[callID]
+ argsDelta := ""
+ if newArgs != "" {
+ if strings.HasPrefix(newArgs, prevArgs) {
+ argsDelta = newArgs[len(prevArgs):]
+ } else {
+ argsDelta = newArgs
+ }
+ toolCallArgsByID[callID] = newArgs
+ }
+
+ if !sendToolCallDelta(callID, name, argsDelta) {
+ return false
+ }
+
+ case "response.function_call_arguments.delta":
+ itemID := strings.TrimSpace(streamResp.ItemID)
+ callID := toolCallCanonicalIDByItemID[itemID]
+ if callID == "" {
+ callID = itemID
+ }
+ if callID == "" {
+ break
+ }
+ toolCallArgsByID[callID] += streamResp.Delta
+ if !sendToolCallDelta(callID, "", streamResp.Delta) {
+ return false
+ }
+
+ case "response.function_call_arguments.done":
+
+ case "response.completed":
+ if streamResp.Response != nil {
+ if streamResp.Response.Model != "" {
+ model = streamResp.Response.Model
+ }
+ if streamResp.Response.CreatedAt != 0 {
+ createAt = int64(streamResp.Response.CreatedAt)
+ }
+ if streamResp.Response.Usage != nil {
+ if streamResp.Response.Usage.InputTokens != 0 {
+ usage.PromptTokens = streamResp.Response.Usage.InputTokens
+ usage.InputTokens = streamResp.Response.Usage.InputTokens
+ }
+ if streamResp.Response.Usage.OutputTokens != 0 {
+ usage.CompletionTokens = streamResp.Response.Usage.OutputTokens
+ usage.OutputTokens = streamResp.Response.Usage.OutputTokens
+ }
+ if streamResp.Response.Usage.TotalTokens != 0 {
+ usage.TotalTokens = streamResp.Response.Usage.TotalTokens
+ } else {
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ }
+ if streamResp.Response.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = streamResp.Response.Usage.InputTokensDetails.CachedTokens
+ usage.PromptTokensDetails.ImageTokens = streamResp.Response.Usage.InputTokensDetails.ImageTokens
+ usage.PromptTokensDetails.AudioTokens = streamResp.Response.Usage.InputTokensDetails.AudioTokens
+ }
+ if streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
+ usage.CompletionTokenDetails.ReasoningTokens = streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens
+ }
+ }
+ }
+
+ if !sendStartIfNeeded() {
+ return false
+ }
+ if !sentStop {
+ finishReason := "stop"
+ if sawToolCall && outputText.Len() == 0 {
+ finishReason = "tool_calls"
+ }
+ stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
+ if err := helper.ObjectData(c, stop); err != nil {
+ streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ return false
+ }
+ sentStop = true
+ }
+
+ case "response.error", "response.failed":
+ if streamResp.Response != nil {
+ if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
+ streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
+ return false
+ }
+ }
+ streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ return false
+
+ default:
+ }
+
+ return true
+ })
+
+ if streamErr != nil {
+ return nil, streamErr
+ }
+
+ if usage.TotalTokens == 0 {
+ usage = service.ResponseText2Usage(c, usageText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
+ }
+
+ if !sentStart {
+ if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+ }
+ if !sentStop {
+ finishReason := "stop"
+ if sawToolCall && outputText.Len() == 0 {
+ finishReason = "tool_calls"
+ }
+ stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
+ if err := helper.ObjectData(c, stop); err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+ }
+ if info.ShouldIncludeUsage && usage != nil {
+ if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+ }
+
+ helper.Done(c)
+ return usage, nil
+}
diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go
index dd21fb75..6051c7e8 100644
--- a/relay/channel/task/doubao/adaptor.go
+++ b/relay/channel/task/doubao/adaptor.go
@@ -6,6 +6,9 @@ import (
"fmt"
"io"
"net/http"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
@@ -23,18 +26,36 @@ import (
// ============================
type ContentItem struct {
- Type string `json:"type"` // "text" or "image_url"
- Text string `json:"text,omitempty"` // for text type
- ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
+ Type string `json:"type"` // "text", "image_url" or "video"
+ Text string `json:"text,omitempty"` // for text type
+ ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
+ Video *VideoReference `json:"video,omitempty"` // for video (sample) type
}
type ImageURL struct {
URL string `json:"url"`
}
+type VideoReference struct {
+ URL string `json:"url"` // Draft video URL
+}
+
type requestPayload struct {
- Model string `json:"model"`
- Content []ContentItem `json:"content"`
+ Model string `json:"model"`
+ Content []ContentItem `json:"content"`
+ CallbackURL string `json:"callback_url,omitempty"`
+ ReturnLastFrame *dto.BoolValue `json:"return_last_frame,omitempty"`
+ ServiceTier string `json:"service_tier,omitempty"`
+ ExecutionExpiresAfter dto.IntValue `json:"execution_expires_after,omitempty"`
+ GenerateAudio *dto.BoolValue `json:"generate_audio,omitempty"`
+ Draft *dto.BoolValue `json:"draft,omitempty"`
+ Resolution string `json:"resolution,omitempty"`
+ Ratio string `json:"ratio,omitempty"`
+ Duration dto.IntValue `json:"duration,omitempty"`
+ Frames dto.IntValue `json:"frames,omitempty"`
+ Seed dto.IntValue `json:"seed,omitempty"`
+ CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
+ Watermark *dto.BoolValue `json:"watermark,omitempty"`
}
type responsePayload struct {
@@ -53,6 +74,7 @@ type responseTask struct {
Duration int `json:"duration"`
Ratio string `json:"ratio"`
FramesPerSecond int `json:"framespersecond"`
+ ServiceTier string `json:"service_tier"`
Usage struct {
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
@@ -98,16 +120,16 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
// BuildRequestBody converts request into Doubao specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
+ req, err := relaycommon.GetTaskRequest(c)
+ if err != nil {
+ return nil, err
}
- req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
+ info.UpstreamModelName = body.Model
data, err := json.Marshal(body)
if err != nil {
return nil, err
@@ -141,7 +163,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
- c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
+ ov := dto.NewOpenAIVideo()
+ ov.ID = dResp.ID
+ ov.TaskID = dResp.ID
+ ov.CreatedAt = time.Now().Unix()
+ ov.Model = info.OriginModelName
+
+ c.JSON(http.StatusOK, ov)
return dResp.ID, responseBody, nil
}
@@ -204,12 +232,15 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
}
}
- // TODO: Add support for additional parameters from metadata
- // such as ratio, duration, seed, etc.
- // metadata := req.Metadata
- // if metadata != nil {
- // // Parse and apply metadata parameters
- // }
+ metadata := req.Metadata
+ medaBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
+ }
+ err = json.Unmarshal(medaBytes, &r)
+ if err != nil {
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
+ }
return &r, nil
}
@@ -229,7 +260,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
case "pending", "queued":
taskResult.Status = model.TaskStatusQueued
taskResult.Progress = "10%"
- case "processing":
+ case "processing", "running":
taskResult.Status = model.TaskStatusInProgress
taskResult.Progress = "50%"
case "succeeded":
@@ -251,3 +282,30 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
+
+func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
+ var dResp responseTask
+ if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+ return nil, errors.Wrap(err, "unmarshal doubao task data failed")
+ }
+
+ openAIVideo := dto.NewOpenAIVideo()
+ openAIVideo.ID = originTask.TaskID
+ openAIVideo.TaskID = originTask.TaskID
+ openAIVideo.Status = originTask.Status.ToVideoStatus()
+ openAIVideo.SetProgressStr(originTask.Progress)
+ openAIVideo.SetMetadata("url", dResp.Content.VideoURL)
+ openAIVideo.CreatedAt = originTask.CreatedAt
+ openAIVideo.CompletedAt = originTask.UpdatedAt
+ openAIVideo.Model = originTask.Properties.OriginModelName
+
+ if dResp.Status == "failed" {
+ openAIVideo.Error = &dto.OpenAIVideoError{
+ Message: "task failed",
+ Code: "failed",
+ }
+ }
+
+ jsonData, _ := common.Marshal(openAIVideo)
+ return jsonData, nil
+}
diff --git a/relay/channel/task/doubao/constants.go b/relay/channel/task/doubao/constants.go
index 74b416c6..13b1b1d9 100644
--- a/relay/channel/task/doubao/constants.go
+++ b/relay/channel/task/doubao/constants.go
@@ -4,6 +4,7 @@ var ModelList = []string{
"doubao-seedance-1-0-pro-250528",
"doubao-seedance-1-0-lite-t2v",
"doubao-seedance-1-0-lite-i2v",
+ "doubao-seedance-1-5-pro-251215",
}
var ChannelName = "doubao-video"
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
index 91d3f236..1522a967 100644
--- a/relay/channel/task/jimeng/adaptor.go
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -17,6 +17,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
+ "github.com/samber/lo"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
@@ -409,14 +410,15 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
// 即梦视频3.0 ReqKey转换
// https://www.volcengine.com/docs/85621/1792707
+ imageLen := lo.Max([]int{len(req.Images), len(r.BinaryDataBase64), len(r.ImageUrls)})
if strings.Contains(r.ReqKey, "jimeng_v30") {
if r.ReqKey == "jimeng_v30_pro" {
// 3.0 pro只有固定的jimeng_ti2v_v30_pro
r.ReqKey = "jimeng_ti2v_v30_pro"
- } else if len(req.Images) > 1 {
+ } else if imageLen > 1 {
// 多张图片:首尾帧生成
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p")
- } else if len(req.Images) == 1 {
+ } else if imageLen == 1 {
// 单张图片:图生视频
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p")
} else {
diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go
new file mode 100644
index 00000000..38dae3c5
--- /dev/null
+++ b/relay/chat_completions_via_responses.go
@@ -0,0 +1,162 @@
+package relay
+
+import (
+ "bytes"
+ "net/http"
+ "strings"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/relay/channel"
+ openaichannel "github.com/QuantumNous/new-api/relay/channel/openai"
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
+ relayconstant "github.com/QuantumNous/new-api/relay/constant"
+ "github.com/QuantumNous/new-api/service"
+ "github.com/QuantumNous/new-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) {
+ if info == nil || request == nil {
+ return
+ }
+ if info.ChannelSetting.SystemPrompt == "" {
+ return
+ }
+
+ systemRole := request.GetSystemRoleName()
+
+ containSystemPrompt := false
+ for _, message := range request.Messages {
+ if message.Role == systemRole {
+ containSystemPrompt = true
+ break
+ }
+ }
+ if !containSystemPrompt {
+ systemMessage := dto.Message{
+ Role: systemRole,
+ Content: info.ChannelSetting.SystemPrompt,
+ }
+ request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
+ return
+ }
+
+ if !info.ChannelSetting.SystemPromptOverride {
+ return
+ }
+
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
+ for i, message := range request.Messages {
+ if message.Role != systemRole {
+ continue
+ }
+ if message.IsStringContent() {
+ request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
+ return
+ }
+ contents := message.ParseContent()
+ contents = append([]dto.MediaContent{
+ {
+ Type: dto.ContentTypeText,
+ Text: info.ChannelSetting.SystemPrompt,
+ },
+ }, contents...)
+ request.Messages[i].Content = contents
+ return
+ }
+}
+
+func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
+ overrideCtx := relaycommon.BuildParamOverrideContext(info)
+ chatJSON, err := common.Marshal(request)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ if len(info.ParamOverride) > 0 {
+ chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+ }
+
+ var overriddenChatReq dto.GeneralOpenAIRequest
+ if err := common.Unmarshal(chatJSON, &overriddenChatReq); err != nil {
+ return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+
+ responsesReq, err := service.ChatCompletionsRequestToResponsesRequest(&overriddenChatReq)
+ if err != nil {
+ return nil, types.NewErrorWithStatusCode(err, types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+ info.AppendRequestConversion(types.RelayFormatOpenAIResponses)
+
+ savedRelayMode := info.RelayMode
+ savedRequestURLPath := info.RequestURLPath
+ defer func() {
+ info.RelayMode = savedRelayMode
+ info.RequestURLPath = savedRequestURLPath
+ }()
+
+ info.RelayMode = relayconstant.RelayModeResponses
+ info.RequestURLPath = "/v1/responses"
+
+ convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *responsesReq)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
+
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ var httpResp *http.Response
+ resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+ if resp == nil {
+ return nil, types.NewOpenAIError(nil, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ httpResp = resp.(*http.Response)
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
+ service.ResetStatusCode(newApiErr, statusCodeMappingStr)
+ return nil, newApiErr
+ }
+
+ if info.IsStream {
+ usage, newApiErr := openaichannel.OaiResponsesToChatStreamHandler(c, info, httpResp)
+ if newApiErr != nil {
+ service.ResetStatusCode(newApiErr, statusCodeMappingStr)
+ return nil, newApiErr
+ }
+ return usage, nil
+ }
+
+ usage, newApiErr := openaichannel.OaiResponsesToChatHandler(c, info, httpResp)
+ if newApiErr != nil {
+ service.ResetStatusCode(newApiErr, statusCodeMappingStr)
+ return nil, newApiErr
+ }
+ return usage, nil
+}
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index 7a18c173..7e05116d 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -110,6 +110,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
diff --git a/relay/common/override.go b/relay/common/override.go
index 872c960f..1a0c2478 100644
--- a/relay/common/override.go
+++ b/relay/common/override.go
@@ -570,18 +570,19 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
// 目前内置以下字段:
-// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
-// - upstream_model:始终为通道映射后的上游模型名。
+// - upstream_model/model:始终为通道映射后的上游模型名。
// - original_model:请求最初指定的模型名。
+// - request_path:请求路径
+// - is_channel_test:是否为渠道测试请求(同 is_test)。
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
- if info == nil || info.ChannelMeta == nil {
+ if info == nil {
return nil
}
ctx := make(map[string]interface{})
- if info.UpstreamModelName != "" {
- ctx["model"] = info.UpstreamModelName
- ctx["upstream_model"] = info.UpstreamModelName
+ if info.ChannelMeta != nil && info.ChannelMeta.UpstreamModelName != "" {
+ ctx["model"] = info.ChannelMeta.UpstreamModelName
+ ctx["upstream_model"] = info.ChannelMeta.UpstreamModelName
}
if info.OriginModelName != "" {
ctx["original_model"] = info.OriginModelName
@@ -590,8 +591,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
}
}
- if len(ctx) == 0 {
- return nil
+ if info.RequestURLPath != "" {
+ requestPath := info.RequestURLPath
+ if requestPath != "" {
+ ctx["request_path"] = requestPath
+ }
}
+
+ ctx["is_channel_test"] = info.IsChannelTest
return ctx
}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index 1b9762fe..5c24ce57 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -115,11 +115,16 @@ type RelayInfo struct {
SendResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
IsClaudeBetaQuery bool // /v1/messages?beta=true
+ IsChannelTest bool // channel test request
PriceData types.PriceData
Request dto.Request
+ // RequestConversionChain records request format conversions in order, e.g.
+ // ["openai", "openai_responses"] or ["openai", "claude"].
+ RequestConversionChain []types.RelayFormat
+
ThinkingContentInfo
TokenCountMeta
*ClaudeConvertInfo
@@ -273,6 +278,7 @@ var streamSupportedChannels = map[int]bool{
constant.ChannelTypeZhipu_v4: true,
constant.ChannelTypeAli: true,
constant.ChannelTypeSubmodel: true,
+ constant.ChannelTypeCodex: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -446,38 +452,83 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
}
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
+ var info *RelayInfo
+ var err error
switch relayFormat {
case types.RelayFormatOpenAI:
- return GenRelayInfoOpenAI(c, request), nil
+ info = GenRelayInfoOpenAI(c, request)
case types.RelayFormatOpenAIAudio:
- return GenRelayInfoOpenAIAudio(c, request), nil
+ info = GenRelayInfoOpenAIAudio(c, request)
case types.RelayFormatOpenAIImage:
- return GenRelayInfoImage(c, request), nil
+ info = GenRelayInfoImage(c, request)
case types.RelayFormatOpenAIRealtime:
- return GenRelayInfoWs(c, ws), nil
+ info = GenRelayInfoWs(c, ws)
case types.RelayFormatClaude:
- return GenRelayInfoClaude(c, request), nil
+ info = GenRelayInfoClaude(c, request)
case types.RelayFormatRerank:
if request, ok := request.(*dto.RerankRequest); ok {
- return GenRelayInfoRerank(c, request), nil
+ info = GenRelayInfoRerank(c, request)
+ break
}
- return nil, errors.New("request is not a RerankRequest")
+ err = errors.New("request is not a RerankRequest")
case types.RelayFormatGemini:
- return GenRelayInfoGemini(c, request), nil
+ info = GenRelayInfoGemini(c, request)
case types.RelayFormatEmbedding:
- return GenRelayInfoEmbedding(c, request), nil
+ info = GenRelayInfoEmbedding(c, request)
case types.RelayFormatOpenAIResponses:
if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
- return GenRelayInfoResponses(c, request), nil
+ info = GenRelayInfoResponses(c, request)
+ break
}
- return nil, errors.New("request is not a OpenAIResponsesRequest")
+ err = errors.New("request is not a OpenAIResponsesRequest")
case types.RelayFormatTask:
- return genBaseRelayInfo(c, nil), nil
+ info = genBaseRelayInfo(c, nil)
case types.RelayFormatMjProxy:
- return genBaseRelayInfo(c, nil), nil
+ info = genBaseRelayInfo(c, nil)
default:
- return nil, errors.New("invalid relay format")
+ err = errors.New("invalid relay format")
}
+
+ if err != nil {
+ return nil, err
+ }
+ if info == nil {
+ return nil, errors.New("failed to build relay info")
+ }
+
+ info.InitRequestConversionChain()
+ return info, nil
+}
+
+func (info *RelayInfo) InitRequestConversionChain() {
+ if info == nil {
+ return
+ }
+ if len(info.RequestConversionChain) > 0 {
+ return
+ }
+ if info.RelayFormat == "" {
+ return
+ }
+ info.RequestConversionChain = []types.RelayFormat{info.RelayFormat}
+}
+
+func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
+ if info == nil {
+ return
+ }
+ if format == "" {
+ return
+ }
+ if len(info.RequestConversionChain) == 0 {
+ info.RequestConversionChain = []types.RelayFormat{format}
+ return
+ }
+ last := info.RequestConversionChain[len(info.RequestConversionChain)-1]
+ if last == format {
+ return
+ }
+ info.RequestConversionChain = append(info.RequestConversionChain, format)
}
//func (info *RelayInfo) SetPromptTokens(promptTokens int) {
diff --git a/relay/common/request_conversion.go b/relay/common/request_conversion.go
new file mode 100644
index 00000000..96b728d2
--- /dev/null
+++ b/relay/common/request_conversion.go
@@ -0,0 +1,40 @@
+package common
+
+import (
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/types"
+)
+
+func GuessRelayFormatFromRequest(req any) (types.RelayFormat, bool) {
+ switch req.(type) {
+ case *dto.GeneralOpenAIRequest, dto.GeneralOpenAIRequest:
+ return types.RelayFormatOpenAI, true
+ case *dto.OpenAIResponsesRequest, dto.OpenAIResponsesRequest:
+ return types.RelayFormatOpenAIResponses, true
+ case *dto.ClaudeRequest, dto.ClaudeRequest:
+ return types.RelayFormatClaude, true
+ case *dto.GeminiChatRequest, dto.GeminiChatRequest:
+ return types.RelayFormatGemini, true
+ case *dto.EmbeddingRequest, dto.EmbeddingRequest:
+ return types.RelayFormatEmbedding, true
+ case *dto.RerankRequest, dto.RerankRequest:
+ return types.RelayFormatRerank, true
+ case *dto.ImageRequest, dto.ImageRequest:
+ return types.RelayFormatOpenAIImage, true
+ case *dto.AudioRequest, dto.AudioRequest:
+ return types.RelayFormatOpenAIAudio, true
+ default:
+ return "", false
+ }
+}
+
+func AppendRequestConversionFromRequest(info *RelayInfo, req any) {
+ if info == nil {
+ return
+ }
+ format, ok := GuessRelayFormatFromRequest(req)
+ if !ok {
+ return
+ }
+ info.AppendRequestConversion(format)
+}
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go
index a536e165..eab5052d 100644
--- a/relay/compatible_handler.go
+++ b/relay/compatible_handler.go
@@ -14,6 +14,7 @@ import (
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
+ relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
@@ -73,9 +74,32 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(info)
+
+ passThroughGlobal := model_setting.GetGlobalSettings().PassThroughRequestEnabled
+ if info.RelayMode == relayconstant.RelayModeChatCompletions &&
+ !passThroughGlobal &&
+ !info.ChannelSetting.PassThroughBodyEnabled &&
+ shouldChatCompletionsViaResponses(info) {
+ applySystemPromptIfNeeded(c, info, request)
+ usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, request)
+ if newApiErr != nil {
+ return newApiErr
+ }
+
+ var containAudioTokens = usage.CompletionTokenDetails.AudioTokens > 0 || usage.PromptTokensDetails.AudioTokens > 0
+ var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
+
+ if containAudioTokens && containsAudioRatios {
+ service.PostAudioConsumeQuota(c, info, usage, "")
+ } else {
+ postConsumeQuota(c, info, usage)
+ }
+ return nil
+ }
+
var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+ if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -89,6 +113,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
@@ -193,6 +218,16 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return nil
}
+func shouldChatCompletionsViaResponses(info *relaycommon.RelayInfo) bool {
+ if info == nil {
+ return false
+ }
+ if info.RelayMode != relayconstant.RelayModeChatCompletions {
+ return false
+ }
+ return service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.OriginModelName)
+}
+
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
if usage == nil {
usage = &dto.Usage{
@@ -301,6 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
var audioInputQuota decimal.Decimal
var audioInputPrice float64
+ isClaudeUsageSemantic := relayInfo.ChannelType == constant.ChannelTypeAnthropic
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
@@ -308,14 +344,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
var cachedTokensWithRatio decimal.Decimal
if !dCacheTokens.IsZero() {
- if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
+ if !isClaudeUsageSemantic {
baseTokens = baseTokens.Sub(dCacheTokens)
}
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
- if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
+ if !isClaudeUsageSemantic {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
}
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
@@ -425,6 +461,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
}
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+ // For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
+ if isClaudeUsageSemantic {
+ other["claude"] = true
+ other["usage_semantic"] = "anthropic"
+ }
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
index 2cedf02b..1a41756b 100644
--- a/relay/embedding_handler.go
+++ b/relay/embedding_handler.go
@@ -45,6 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
index 79ffba51..779670b9 100644
--- a/relay/gemini_handler.go
+++ b/relay/gemini_handler.go
@@ -149,6 +149,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
diff --git a/relay/image_handler.go b/relay/image_handler.go
index f110f4e8..1ee790b7 100644
--- a/relay/image_handler.go
+++ b/relay/image_handler.go
@@ -57,6 +57,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
switch convertedRequest.(type) {
case *bytes.Buffer:
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index b838b313..3139c9a2 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel/baidu_v2"
"github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/cloudflare"
+ "github.com/QuantumNous/new-api/relay/channel/codex"
"github.com/QuantumNous/new-api/relay/channel/cohere"
"github.com/QuantumNous/new-api/relay/channel/coze"
"github.com/QuantumNous/new-api/relay/channel/deepseek"
@@ -117,6 +118,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &minimax.Adaptor{}
case constant.APITypeReplicate:
return &replicate.Adaptor{}
+ case constant.APITypeCodex:
+ return &codex.Adaptor{}
}
return nil
}
@@ -148,7 +151,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &taskvertex.TaskAdaptor{}
case constant.ChannelTypeVidu:
return &taskVidu.TaskAdaptor{}
- case constant.ChannelTypeDoubaoVideo:
+ case constant.ChannelTypeDoubaoVideo, constant.ChannelTypeVolcEngine:
return &taskdoubao.TaskAdaptor{}
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
return &tasksora.TaskAdaptor{}
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
index 9a50fd27..35c66a29 100644
--- a/relay/rerank_handler.go
+++ b/relay/rerank_handler.go
@@ -53,6 +53,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
index 5c3d9a42..769437a1 100644
--- a/relay/responses_handler.go
+++ b/relay/responses_handler.go
@@ -53,6 +53,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
+ relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
diff --git a/router/api-router.go b/router/api-router.go
index 9b2bd061..f3ae4d97 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -156,6 +156,12 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
+ channelRoute.POST("/codex/oauth/start", controller.StartCodexOAuth)
+ channelRoute.POST("/codex/oauth/complete", controller.CompleteCodexOAuth)
+ channelRoute.POST("/:id/codex/oauth/start", controller.StartCodexOAuthForChannel)
+ channelRoute.POST("/:id/codex/oauth/complete", controller.CompleteCodexOAuthForChannel)
+ channelRoute.POST("/:id/codex/refresh", controller.RefreshCodexChannelCredential)
+ channelRoute.GET("/:id/codex/usage", controller.GetCodexChannelUsage)
channelRoute.POST("/ollama/pull", controller.OllamaPullModel)
channelRoute.POST("/ollama/pull/stream", controller.OllamaPullModelStream)
channelRoute.DELETE("/ollama/delete", controller.OllamaDeleteModel)
diff --git a/service/channel.go b/service/channel.go
index 8f8a3572..96bc1efe 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -57,9 +57,12 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
if types.IsSkipRetryError(err) {
return false
}
- if err.StatusCode == http.StatusUnauthorized {
+ if operation_setting.ShouldDisableByStatusCode(err.StatusCode) {
return true
}
+ //if err.StatusCode == http.StatusUnauthorized {
+ // return true
+ //}
if err.StatusCode == http.StatusForbidden {
switch channelType {
case constant.ChannelTypeGemini:
diff --git a/service/codex_credential_refresh.go b/service/codex_credential_refresh.go
new file mode 100644
index 00000000..0290fe51
--- /dev/null
+++ b/service/codex_credential_refresh.go
@@ -0,0 +1,104 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/model"
+)
+
+type CodexCredentialRefreshOptions struct {
+ ResetCaches bool
+}
+
+type CodexOAuthKey struct {
+ IDToken string `json:"id_token,omitempty"`
+ AccessToken string `json:"access_token,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+
+ AccountID string `json:"account_id,omitempty"`
+ LastRefresh string `json:"last_refresh,omitempty"`
+ Email string `json:"email,omitempty"`
+ Type string `json:"type,omitempty"`
+ Expired string `json:"expired,omitempty"`
+}
+
+func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
+ if strings.TrimSpace(raw) == "" {
+ return nil, errors.New("codex channel: empty oauth key")
+ }
+ var key CodexOAuthKey
+ if err := common.Unmarshal([]byte(raw), &key); err != nil {
+ return nil, errors.New("codex channel: invalid oauth key json")
+ }
+ return &key, nil
+}
+
+func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
+ ch, err := model.GetChannelById(channelID, true)
+ if err != nil {
+ return nil, nil, err
+ }
+ if ch == nil {
+ return nil, nil, fmt.Errorf("channel not found")
+ }
+ if ch.Type != constant.ChannelTypeCodex {
+ return nil, nil, fmt.Errorf("channel type is not Codex")
+ }
+
+ oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
+ if err != nil {
+ return nil, nil, err
+ }
+ if strings.TrimSpace(oauthKey.RefreshToken) == "" {
+ return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
+ }
+
+ refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ oauthKey.AccessToken = res.AccessToken
+ oauthKey.RefreshToken = res.RefreshToken
+ oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
+ oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
+ if strings.TrimSpace(oauthKey.Type) == "" {
+ oauthKey.Type = "codex"
+ }
+
+ if strings.TrimSpace(oauthKey.AccountID) == "" {
+ if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
+ oauthKey.AccountID = accountID
+ }
+ }
+ if strings.TrimSpace(oauthKey.Email) == "" {
+ if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
+ oauthKey.Email = email
+ }
+ }
+
+ encoded, err := common.Marshal(oauthKey)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil {
+ return nil, nil, err
+ }
+
+ if opts.ResetCaches {
+ model.InitChannelCache()
+ ResetProxyClientCache()
+ }
+
+ return oauthKey, ch, nil
+}
diff --git a/service/codex_credential_refresh_task.go b/service/codex_credential_refresh_task.go
new file mode 100644
index 00000000..627ab929
--- /dev/null
+++ b/service/codex_credential_refresh_task.go
@@ -0,0 +1,140 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/constant"
+ "github.com/QuantumNous/new-api/logger"
+ "github.com/QuantumNous/new-api/model"
+
+ "github.com/bytedance/gopkg/util/gopool"
+)
+
+const (
+ codexCredentialRefreshTickInterval = 10 * time.Minute
+ codexCredentialRefreshThreshold = 24 * time.Hour
+ codexCredentialRefreshBatchSize = 200
+ codexCredentialRefreshTimeout = 15 * time.Second
+)
+
+var (
+ codexCredentialRefreshOnce sync.Once
+ codexCredentialRefreshRunning atomic.Bool
+)
+
+func StartCodexCredentialAutoRefreshTask() {
+ codexCredentialRefreshOnce.Do(func() {
+ if !common.IsMasterNode {
+ return
+ }
+
+ gopool.Go(func() {
+ logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))
+
+ ticker := time.NewTicker(codexCredentialRefreshTickInterval)
+ defer ticker.Stop()
+
+ runCodexCredentialAutoRefreshOnce()
+ for range ticker.C {
+ runCodexCredentialAutoRefreshOnce()
+ }
+ })
+ })
+}
+
+func runCodexCredentialAutoRefreshOnce() {
+ if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
+ return
+ }
+ defer codexCredentialRefreshRunning.Store(false)
+
+ ctx := context.Background()
+ now := time.Now()
+
+ var refreshed int
+ var scanned int
+
+ offset := 0
+ for {
+ var channels []*model.Channel
+ err := model.DB.
+ Select("id", "name", "key", "status", "channel_info").
+ Where("type = ? AND status = 1", constant.ChannelTypeCodex).
+ Order("id asc").
+ Limit(codexCredentialRefreshBatchSize).
+ Offset(offset).
+ Find(&channels).Error
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
+ return
+ }
+ if len(channels) == 0 {
+ break
+ }
+ offset += codexCredentialRefreshBatchSize
+
+ for _, ch := range channels {
+ if ch == nil {
+ continue
+ }
+ scanned++
+ if ch.ChannelInfo.IsMultiKey {
+ continue
+ }
+
+ rawKey := strings.TrimSpace(ch.Key)
+ if rawKey == "" {
+ continue
+ }
+
+ oauthKey, err := parseCodexOAuthKey(rawKey)
+ if err != nil {
+ continue
+ }
+
+ refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
+ if refreshToken == "" {
+ continue
+ }
+
+ expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
+ expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
+ if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
+ continue
+ }
+
+ refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
+ newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
+ cancel()
+ if err != nil {
+ logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
+ continue
+ }
+
+ refreshed++
+ logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
+ }
+ }
+
+ if refreshed > 0 {
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
+ }
+ }()
+ model.InitChannelCache()
+ }()
+ ResetProxyClientCache()
+ }
+
+ if common.DebugEnabled {
+ logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
+ }
+}
diff --git a/service/codex_oauth.go b/service/codex_oauth.go
new file mode 100644
index 00000000..4c2dce1c
--- /dev/null
+++ b/service/codex_oauth.go
@@ -0,0 +1,288 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+const (
+ codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+ codexOAuthAuthorizeURL = "https://auth.openai.com/oauth/authorize"
+ codexOAuthTokenURL = "https://auth.openai.com/oauth/token"
+ codexOAuthRedirectURI = "http://localhost:1455/auth/callback"
+ codexOAuthScope = "openid profile email offline_access"
+ codexJWTClaimPath = "https://api.openai.com/auth"
+ defaultHTTPTimeout = 20 * time.Second
+)
+
+type CodexOAuthTokenResult struct {
+ AccessToken string
+ RefreshToken string
+ ExpiresAt time.Time
+}
+
+type CodexOAuthAuthorizationFlow struct {
+ State string
+ Verifier string
+ Challenge string
+ AuthorizeURL string
+}
+
+func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
+ client := &http.Client{Timeout: defaultHTTPTimeout}
+ return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
+}
+
+func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
+ client := &http.Client{Timeout: defaultHTTPTimeout}
+ return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
+}
+
+func CreateCodexOAuthAuthorizationFlow() (*CodexOAuthAuthorizationFlow, error) {
+ state, err := createStateHex(16)
+ if err != nil {
+ return nil, err
+ }
+ verifier, challenge, err := generatePKCEPair()
+ if err != nil {
+ return nil, err
+ }
+ u, err := buildCodexAuthorizeURL(state, challenge)
+ if err != nil {
+ return nil, err
+ }
+ return &CodexOAuthAuthorizationFlow{
+ State: state,
+ Verifier: verifier,
+ Challenge: challenge,
+ AuthorizeURL: u,
+ }, nil
+}
+
+func refreshCodexOAuthToken(
+ ctx context.Context,
+ client *http.Client,
+ tokenURL string,
+ clientID string,
+ refreshToken string,
+) (*CodexOAuthTokenResult, error) {
+ rt := strings.TrimSpace(refreshToken)
+ if rt == "" {
+ return nil, errors.New("empty refresh_token")
+ }
+
+ form := url.Values{}
+ form.Set("grant_type", "refresh_token")
+ form.Set("refresh_token", rt)
+ form.Set("client_id", clientID)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var payload struct {
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ ExpiresIn int `json:"expires_in"`
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+ return nil, err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("codex oauth refresh failed: status=%d", resp.StatusCode)
+ }
+
+ if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
+ return nil, errors.New("codex oauth refresh response missing fields")
+ }
+
+ return &CodexOAuthTokenResult{
+ AccessToken: strings.TrimSpace(payload.AccessToken),
+ RefreshToken: strings.TrimSpace(payload.RefreshToken),
+ ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
+ }, nil
+}
+
+func exchangeCodexAuthorizationCode(
+ ctx context.Context,
+ client *http.Client,
+ tokenURL string,
+ clientID string,
+ code string,
+ verifier string,
+ redirectURI string,
+) (*CodexOAuthTokenResult, error) {
+ c := strings.TrimSpace(code)
+ v := strings.TrimSpace(verifier)
+ if c == "" {
+ return nil, errors.New("empty authorization code")
+ }
+ if v == "" {
+ return nil, errors.New("empty code_verifier")
+ }
+
+ form := url.Values{}
+ form.Set("grant_type", "authorization_code")
+ form.Set("client_id", clientID)
+ form.Set("code", c)
+ form.Set("code_verifier", v)
+ form.Set("redirect_uri", redirectURI)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var payload struct {
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ ExpiresIn int `json:"expires_in"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+ return nil, err
+ }
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("codex oauth code exchange failed: status=%d", resp.StatusCode)
+ }
+ if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
+ return nil, errors.New("codex oauth token response missing fields")
+ }
+ return &CodexOAuthTokenResult{
+ AccessToken: strings.TrimSpace(payload.AccessToken),
+ RefreshToken: strings.TrimSpace(payload.RefreshToken),
+ ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
+ }, nil
+}
+
+func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
+ u, err := url.Parse(codexOAuthAuthorizeURL)
+ if err != nil {
+ return "", err
+ }
+ q := u.Query()
+ q.Set("response_type", "code")
+ q.Set("client_id", codexOAuthClientID)
+ q.Set("redirect_uri", codexOAuthRedirectURI)
+ q.Set("scope", codexOAuthScope)
+ q.Set("code_challenge", challenge)
+ q.Set("code_challenge_method", "S256")
+ q.Set("state", state)
+ q.Set("id_token_add_organizations", "true")
+ q.Set("codex_cli_simplified_flow", "true")
+ q.Set("originator", "codex_cli_rs")
+ u.RawQuery = q.Encode()
+ return u.String(), nil
+}
+
+func createStateHex(nBytes int) (string, error) {
+ if nBytes <= 0 {
+ return "", errors.New("invalid state bytes length")
+ }
+ b := make([]byte, nBytes)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return fmt.Sprintf("%x", b), nil
+}
+
+func generatePKCEPair() (verifier string, challenge string, err error) {
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", "", err
+ }
+ verifier = base64.RawURLEncoding.EncodeToString(b)
+ sum := sha256.Sum256([]byte(verifier))
+ challenge = base64.RawURLEncoding.EncodeToString(sum[:])
+ return verifier, challenge, nil
+}
+
+func ExtractCodexAccountIDFromJWT(token string) (string, bool) {
+ claims, ok := decodeJWTClaims(token)
+ if !ok {
+ return "", false
+ }
+ raw, ok := claims[codexJWTClaimPath]
+ if !ok {
+ return "", false
+ }
+ obj, ok := raw.(map[string]any)
+ if !ok {
+ return "", false
+ }
+ v, ok := obj["chatgpt_account_id"]
+ if !ok {
+ return "", false
+ }
+ s, ok := v.(string)
+ if !ok {
+ return "", false
+ }
+ s = strings.TrimSpace(s)
+ if s == "" {
+ return "", false
+ }
+ return s, true
+}
+
+func ExtractEmailFromJWT(token string) (string, bool) {
+ claims, ok := decodeJWTClaims(token)
+ if !ok {
+ return "", false
+ }
+ v, ok := claims["email"]
+ if !ok {
+ return "", false
+ }
+ s, ok := v.(string)
+ if !ok {
+ return "", false
+ }
+ s = strings.TrimSpace(s)
+ if s == "" {
+ return "", false
+ }
+ return s, true
+}
+
+func decodeJWTClaims(token string) (map[string]any, bool) {
+ parts := strings.Split(token, ".")
+ if len(parts) != 3 {
+ return nil, false
+ }
+ payloadRaw, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return nil, false
+ }
+ var claims map[string]any
+ if err := json.Unmarshal(payloadRaw, &claims); err != nil {
+ return nil, false
+ }
+ return claims, true
+}
diff --git a/service/codex_wham_usage.go b/service/codex_wham_usage.go
new file mode 100644
index 00000000..d27cbd9d
--- /dev/null
+++ b/service/codex_wham_usage.go
@@ -0,0 +1,56 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+)
+
+func FetchCodexWhamUsage(
+ ctx context.Context,
+ client *http.Client,
+ baseURL string,
+ accessToken string,
+ accountID string,
+) (statusCode int, body []byte, err error) {
+ if client == nil {
+ return 0, nil, fmt.Errorf("nil http client")
+ }
+ bu := strings.TrimRight(strings.TrimSpace(baseURL), "/")
+ if bu == "" {
+ return 0, nil, fmt.Errorf("empty baseURL")
+ }
+ at := strings.TrimSpace(accessToken)
+ aid := strings.TrimSpace(accountID)
+ if at == "" {
+ return 0, nil, fmt.Errorf("empty accessToken")
+ }
+ if aid == "" {
+ return 0, nil, fmt.Errorf("empty accountID")
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, bu+"/backend-api/wham/usage", nil)
+ if err != nil {
+ return 0, nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+at)
+ req.Header.Set("chatgpt-account-id", aid)
+ req.Header.Set("Accept", "application/json")
+ if req.Header.Get("originator") == "" {
+ req.Header.Set("originator", "codex_cli_rs")
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer resp.Body.Close()
+
+ body, err = io.ReadAll(resp.Body)
+ if err != nil {
+ return resp.StatusCode, nil, err
+ }
+ return resp.StatusCode, body, nil
+}
diff --git a/service/http_client.go b/service/http_client.go
index 783aac89..2c3168f2 100644
--- a/service/http_client.go
+++ b/service/http_client.go
@@ -40,6 +40,9 @@ func InitHttpClient() {
ForceAttemptHTTP2: true,
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
}
+ if common.TLSInsecureSkipVerify {
+ transport.TLSClientConfig = common.InsecureTLSConfig
+ }
if common.RelayTimeout == 0 {
httpClient = &http.Client{
@@ -102,13 +105,17 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
switch parsedURL.Scheme {
case "http", "https":
+ transport := &http.Transport{
+ MaxIdleConns: common.RelayMaxIdleConns,
+ MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
+ ForceAttemptHTTP2: true,
+ Proxy: http.ProxyURL(parsedURL),
+ }
+ if common.TLSInsecureSkipVerify {
+ transport.TLSClientConfig = common.InsecureTLSConfig
+ }
client := &http.Client{
- Transport: &http.Transport{
- MaxIdleConns: common.RelayMaxIdleConns,
- MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
- ForceAttemptHTTP2: true,
- Proxy: http.ProxyURL(parsedURL),
- },
+ Transport: transport,
CheckRedirect: checkRedirect,
}
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
@@ -137,17 +144,19 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
return nil, err
}
- client := &http.Client{
- Transport: &http.Transport{
- MaxIdleConns: common.RelayMaxIdleConns,
- MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
- ForceAttemptHTTP2: true,
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return dialer.Dial(network, addr)
- },
+ transport := &http.Transport{
+ MaxIdleConns: common.RelayMaxIdleConns,
+ MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
+ ForceAttemptHTTP2: true,
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return dialer.Dial(network, addr)
},
- CheckRedirect: checkRedirect,
}
+ if common.TLSInsecureSkipVerify {
+ transport.TLSClientConfig = common.InsecureTLSConfig
+ }
+
+ client := &http.Client{Transport: transport, CheckRedirect: checkRedirect}
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
proxyClientLock.Lock()
proxyClients[proxyURL] = client
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 1bd7df67..71a6bd32 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -70,9 +70,38 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["admin_info"] = adminInfo
appendRequestPath(ctx, relayInfo, other)
+ appendRequestConversionChain(relayInfo, other)
return other
}
+func appendRequestConversionChain(relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
+ if relayInfo == nil || other == nil {
+ return
+ }
+ if len(relayInfo.RequestConversionChain) == 0 {
+ return
+ }
+ chain := make([]string, 0, len(relayInfo.RequestConversionChain))
+ for _, f := range relayInfo.RequestConversionChain {
+ switch f {
+ case types.RelayFormatOpenAI:
+ chain = append(chain, "OpenAI Compatible")
+ case types.RelayFormatClaude:
+ chain = append(chain, "Claude Messages")
+ case types.RelayFormatGemini:
+ chain = append(chain, "Google Gemini")
+ case types.RelayFormatOpenAIResponses:
+ chain = append(chain, "OpenAI Responses")
+ default:
+ chain = append(chain, string(f))
+ }
+ }
+ if len(chain) == 0 {
+ return
+ }
+ other["request_conversion"] = chain
+}
+
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["ws"] = true
diff --git a/service/openai_chat_responses_compat.go b/service/openai_chat_responses_compat.go
new file mode 100644
index 00000000..2e887386
--- /dev/null
+++ b/service/openai_chat_responses_compat.go
@@ -0,0 +1,18 @@
+package service
+
+import (
+ "github.com/QuantumNous/new-api/dto"
+ "github.com/QuantumNous/new-api/service/openaicompat"
+)
+
+func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*dto.OpenAIResponsesRequest, error) {
+ return openaicompat.ChatCompletionsRequestToResponsesRequest(req)
+}
+
+func ResponsesResponseToChatCompletionsResponse(resp *dto.OpenAIResponsesResponse, id string) (*dto.OpenAITextResponse, *dto.Usage, error) {
+ return openaicompat.ResponsesResponseToChatCompletionsResponse(resp, id)
+}
+
+func ExtractOutputTextFromResponses(resp *dto.OpenAIResponsesResponse) string {
+ return openaicompat.ExtractOutputTextFromResponses(resp)
+}
diff --git a/service/openai_chat_responses_mode.go b/service/openai_chat_responses_mode.go
new file mode 100644
index 00000000..a655a38b
--- /dev/null
+++ b/service/openai_chat_responses_mode.go
@@ -0,0 +1,14 @@
+package service
+
+import (
+ "github.com/QuantumNous/new-api/service/openaicompat"
+ "github.com/QuantumNous/new-api/setting/model_setting"
+)
+
+func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
+ return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, model)
+}
+
+func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
+ return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, model)
+}
diff --git a/service/openaicompat/chat_to_responses.go b/service/openaicompat/chat_to_responses.go
new file mode 100644
index 00000000..3779db93
--- /dev/null
+++ b/service/openaicompat/chat_to_responses.go
@@ -0,0 +1,356 @@
+package openaicompat
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/dto"
+)
+
+func normalizeChatImageURLToString(v any) any {
+ switch vv := v.(type) {
+ case string:
+ return vv
+ case map[string]any:
+ if url := common.Interface2String(vv["url"]); url != "" {
+ return url
+ }
+ return v
+ case dto.MessageImageUrl:
+ if vv.Url != "" {
+ return vv.Url
+ }
+ return v
+ case *dto.MessageImageUrl:
+ if vv != nil && vv.Url != "" {
+ return vv.Url
+ }
+ return v
+ default:
+ return v
+ }
+}
+
+func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*dto.OpenAIResponsesRequest, error) {
+ if req == nil {
+ return nil, errors.New("request is nil")
+ }
+ if req.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if req.N > 1 {
+ return nil, fmt.Errorf("n>1 is not supported in responses compatibility mode")
+ }
+
+ var instructionsParts []string
+ inputItems := make([]map[string]any, 0, len(req.Messages))
+
+ for _, msg := range req.Messages {
+ role := strings.TrimSpace(msg.Role)
+ if role == "" {
+ continue
+ }
+
+ if role == "tool" || role == "function" {
+ callID := strings.TrimSpace(msg.ToolCallId)
+
+ var output any
+ if msg.Content == nil {
+ output = ""
+ } else if msg.IsStringContent() {
+ output = msg.StringContent()
+ } else {
+ if b, err := common.Marshal(msg.Content); err == nil {
+ output = string(b)
+ } else {
+ output = fmt.Sprintf("%v", msg.Content)
+ }
+ }
+
+ if callID == "" {
+ inputItems = append(inputItems, map[string]any{
+ "role": "user",
+ "content": fmt.Sprintf("[tool_output_missing_call_id] %v", output),
+ })
+ continue
+ }
+
+ inputItems = append(inputItems, map[string]any{
+ "type": "function_call_output",
+ "call_id": callID,
+ "output": output,
+ })
+ continue
+ }
+
+ // Prefer mapping system/developer messages into `instructions`.
+ if role == "system" || role == "developer" {
+ if msg.Content == nil {
+ continue
+ }
+ if msg.IsStringContent() {
+ if s := strings.TrimSpace(msg.StringContent()); s != "" {
+ instructionsParts = append(instructionsParts, s)
+ }
+ continue
+ }
+ parts := msg.ParseContent()
+ var sb strings.Builder
+ for _, part := range parts {
+ if part.Type == dto.ContentTypeText && strings.TrimSpace(part.Text) != "" {
+ if sb.Len() > 0 {
+ sb.WriteString("\n")
+ }
+ sb.WriteString(part.Text)
+ }
+ }
+ if s := strings.TrimSpace(sb.String()); s != "" {
+ instructionsParts = append(instructionsParts, s)
+ }
+ continue
+ }
+
+ item := map[string]any{
+ "role": role,
+ }
+
+ if msg.Content == nil {
+ item["content"] = ""
+ inputItems = append(inputItems, item)
+
+ if role == "assistant" {
+ for _, tc := range msg.ParseToolCalls() {
+ if strings.TrimSpace(tc.ID) == "" {
+ continue
+ }
+ if tc.Type != "" && tc.Type != "function" {
+ continue
+ }
+ name := strings.TrimSpace(tc.Function.Name)
+ if name == "" {
+ continue
+ }
+ inputItems = append(inputItems, map[string]any{
+ "type": "function_call",
+ "call_id": tc.ID,
+ "name": name,
+ "arguments": tc.Function.Arguments,
+ })
+ }
+ }
+ continue
+ }
+
+ if msg.IsStringContent() {
+ item["content"] = msg.StringContent()
+ inputItems = append(inputItems, item)
+
+ if role == "assistant" {
+ for _, tc := range msg.ParseToolCalls() {
+ if strings.TrimSpace(tc.ID) == "" {
+ continue
+ }
+ if tc.Type != "" && tc.Type != "function" {
+ continue
+ }
+ name := strings.TrimSpace(tc.Function.Name)
+ if name == "" {
+ continue
+ }
+ inputItems = append(inputItems, map[string]any{
+ "type": "function_call",
+ "call_id": tc.ID,
+ "name": name,
+ "arguments": tc.Function.Arguments,
+ })
+ }
+ }
+ continue
+ }
+
+ parts := msg.ParseContent()
+ contentParts := make([]map[string]any, 0, len(parts))
+ for _, part := range parts {
+ switch part.Type {
+ case dto.ContentTypeText:
+ contentParts = append(contentParts, map[string]any{
+ "type": "input_text",
+ "text": part.Text,
+ })
+ case dto.ContentTypeImageURL:
+ contentParts = append(contentParts, map[string]any{
+ "type": "input_image",
+ "image_url": normalizeChatImageURLToString(part.ImageUrl),
+ })
+ case dto.ContentTypeInputAudio:
+ contentParts = append(contentParts, map[string]any{
+ "type": "input_audio",
+ "input_audio": part.InputAudio,
+ })
+ case dto.ContentTypeFile:
+ contentParts = append(contentParts, map[string]any{
+ "type": "input_file",
+ "file": part.File,
+ })
+ case dto.ContentTypeVideoUrl:
+ contentParts = append(contentParts, map[string]any{
+ "type": "input_video",
+ "video_url": part.VideoUrl,
+ })
+ default:
+ contentParts = append(contentParts, map[string]any{
+ "type": part.Type,
+ })
+ }
+ }
+ item["content"] = contentParts
+ inputItems = append(inputItems, item)
+
+ if role == "assistant" {
+ for _, tc := range msg.ParseToolCalls() {
+ if strings.TrimSpace(tc.ID) == "" {
+ continue
+ }
+ if tc.Type != "" && tc.Type != "function" {
+ continue
+ }
+ name := strings.TrimSpace(tc.Function.Name)
+ if name == "" {
+ continue
+ }
+ inputItems = append(inputItems, map[string]any{
+ "type": "function_call",
+ "call_id": tc.ID,
+ "name": name,
+ "arguments": tc.Function.Arguments,
+ })
+ }
+ }
+ }
+
+ inputRaw, err := common.Marshal(inputItems)
+ if err != nil {
+ return nil, err
+ }
+
+ var instructionsRaw json.RawMessage
+ if len(instructionsParts) > 0 {
+ instructions := strings.Join(instructionsParts, "\n\n")
+ instructionsRaw, _ = common.Marshal(instructions)
+ }
+
+ var toolsRaw json.RawMessage
+ if req.Tools != nil {
+ tools := make([]map[string]any, 0, len(req.Tools))
+ for _, tool := range req.Tools {
+ switch tool.Type {
+ case "function":
+ tools = append(tools, map[string]any{
+ "type": "function",
+ "name": tool.Function.Name,
+ "description": tool.Function.Description,
+ "parameters": tool.Function.Parameters,
+ })
+ default:
+ // Best-effort: keep original tool shape for unknown types.
+ var m map[string]any
+ if b, err := common.Marshal(tool); err == nil {
+ _ = common.Unmarshal(b, &m)
+ }
+ if len(m) == 0 {
+ m = map[string]any{"type": tool.Type}
+ }
+ tools = append(tools, m)
+ }
+ }
+ toolsRaw, _ = common.Marshal(tools)
+ }
+
+ var toolChoiceRaw json.RawMessage
+ if req.ToolChoice != nil {
+ switch v := req.ToolChoice.(type) {
+ case string:
+ toolChoiceRaw, _ = common.Marshal(v)
+ default:
+ var m map[string]any
+ if b, err := common.Marshal(v); err == nil {
+ _ = common.Unmarshal(b, &m)
+ }
+ if m == nil {
+ toolChoiceRaw, _ = common.Marshal(v)
+ } else if t, _ := m["type"].(string); t == "function" {
+ // Chat: {"type":"function","function":{"name":"..."}}
+ // Responses: {"type":"function","name":"..."}
+ if name, ok := m["name"].(string); ok && name != "" {
+ toolChoiceRaw, _ = common.Marshal(map[string]any{
+ "type": "function",
+ "name": name,
+ })
+ } else if fn, ok := m["function"].(map[string]any); ok {
+ if name, ok := fn["name"].(string); ok && name != "" {
+ toolChoiceRaw, _ = common.Marshal(map[string]any{
+ "type": "function",
+ "name": name,
+ })
+ } else {
+ toolChoiceRaw, _ = common.Marshal(v)
+ }
+ } else {
+ toolChoiceRaw, _ = common.Marshal(v)
+ }
+ } else {
+ toolChoiceRaw, _ = common.Marshal(v)
+ }
+ }
+ }
+
+ var parallelToolCallsRaw json.RawMessage
+ if req.ParallelTooCalls != nil {
+ parallelToolCallsRaw, _ = common.Marshal(*req.ParallelTooCalls)
+ }
+
+ var textRaw json.RawMessage
+ if req.ResponseFormat != nil && req.ResponseFormat.Type != "" {
+ textRaw, _ = common.Marshal(map[string]any{
+ "format": req.ResponseFormat,
+ })
+ }
+
+ maxOutputTokens := req.MaxTokens
+ if req.MaxCompletionTokens > maxOutputTokens {
+ maxOutputTokens = req.MaxCompletionTokens
+ }
+
+ var topP *float64
+ if req.TopP != 0 {
+ topP = common.GetPointer(req.TopP)
+ }
+
+ out := &dto.OpenAIResponsesRequest{
+ Model: req.Model,
+ Input: inputRaw,
+ Instructions: instructionsRaw,
+ MaxOutputTokens: maxOutputTokens,
+ Stream: req.Stream,
+ Temperature: req.Temperature,
+ Text: textRaw,
+ ToolChoice: toolChoiceRaw,
+ Tools: toolsRaw,
+ TopP: topP,
+ User: req.User,
+ ParallelToolCalls: parallelToolCallsRaw,
+ Store: req.Store,
+ Metadata: req.Metadata,
+ }
+
+ if req.ReasoningEffort != "" && req.ReasoningEffort != "none" {
+ out.Reasoning = &dto.Reasoning{
+ Effort: req.ReasoningEffort,
+ }
+ }
+
+ return out, nil
+}
diff --git a/service/openaicompat/policy.go b/service/openaicompat/policy.go
new file mode 100644
index 00000000..39b11ce5
--- /dev/null
+++ b/service/openaicompat/policy.go
@@ -0,0 +1,18 @@
+package openaicompat
+
+import "github.com/QuantumNous/new-api/setting/model_setting"
+
+func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
+ if !policy.IsChannelEnabled(channelID) {
+ return false
+ }
+ return matchAnyRegex(policy.ModelPatterns, model)
+}
+
+func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
+ return ShouldChatCompletionsUseResponsesPolicy(
+ model_setting.GetGlobalSettings().ChatCompletionsToResponsesPolicy,
+ channelID,
+ model,
+ )
+}
diff --git a/service/openaicompat/regex.go b/service/openaicompat/regex.go
new file mode 100644
index 00000000..4ad5e929
--- /dev/null
+++ b/service/openaicompat/regex.go
@@ -0,0 +1,33 @@
+package openaicompat
+
+import (
+ "regexp"
+ "sync"
+)
+
+var compiledRegexCache sync.Map // map[string]*regexp.Regexp
+
+func matchAnyRegex(patterns []string, s string) bool {
+ if len(patterns) == 0 || s == "" {
+ return false
+ }
+ for _, pattern := range patterns {
+ if pattern == "" {
+ continue
+ }
+ re, ok := compiledRegexCache.Load(pattern)
+ if !ok {
+ compiled, err := regexp.Compile(pattern)
+ if err != nil {
+ // Treat invalid patterns as non-matching to avoid breaking runtime traffic.
+ continue
+ }
+ re = compiled
+ compiledRegexCache.Store(pattern, re)
+ }
+ if re.(*regexp.Regexp).MatchString(s) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/service/openaicompat/responses_to_chat.go b/service/openaicompat/responses_to_chat.go
new file mode 100644
index 00000000..abd03592
--- /dev/null
+++ b/service/openaicompat/responses_to_chat.go
@@ -0,0 +1,133 @@
+package openaicompat
+
+import (
+ "errors"
+ "strings"
+
+ "github.com/QuantumNous/new-api/dto"
+)
+
+func ResponsesResponseToChatCompletionsResponse(resp *dto.OpenAIResponsesResponse, id string) (*dto.OpenAITextResponse, *dto.Usage, error) {
+ if resp == nil {
+ return nil, nil, errors.New("response is nil")
+ }
+
+ text := ExtractOutputTextFromResponses(resp)
+
+ usage := &dto.Usage{}
+ if resp.Usage != nil {
+ if resp.Usage.InputTokens != 0 {
+ usage.PromptTokens = resp.Usage.InputTokens
+ usage.InputTokens = resp.Usage.InputTokens
+ }
+ if resp.Usage.OutputTokens != 0 {
+ usage.CompletionTokens = resp.Usage.OutputTokens
+ usage.OutputTokens = resp.Usage.OutputTokens
+ }
+ if resp.Usage.TotalTokens != 0 {
+ usage.TotalTokens = resp.Usage.TotalTokens
+ } else {
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ }
+ if resp.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = resp.Usage.InputTokensDetails.CachedTokens
+ usage.PromptTokensDetails.ImageTokens = resp.Usage.InputTokensDetails.ImageTokens
+ usage.PromptTokensDetails.AudioTokens = resp.Usage.InputTokensDetails.AudioTokens
+ }
+ if resp.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
+ usage.CompletionTokenDetails.ReasoningTokens = resp.Usage.CompletionTokenDetails.ReasoningTokens
+ }
+ }
+
+ created := resp.CreatedAt
+
+ var toolCalls []dto.ToolCallResponse
+ if text == "" && len(resp.Output) > 0 {
+ for _, out := range resp.Output {
+ if out.Type != "function_call" {
+ continue
+ }
+ name := strings.TrimSpace(out.Name)
+ if name == "" {
+ continue
+ }
+ callId := strings.TrimSpace(out.CallId)
+ if callId == "" {
+ callId = strings.TrimSpace(out.ID)
+ }
+ toolCalls = append(toolCalls, dto.ToolCallResponse{
+ ID: callId,
+ Type: "function",
+ Function: dto.FunctionResponse{
+ Name: name,
+ Arguments: out.Arguments,
+ },
+ })
+ }
+ }
+
+ finishReason := "stop"
+ if len(toolCalls) > 0 {
+ finishReason = "tool_calls"
+ }
+
+ msg := dto.Message{
+ Role: "assistant",
+ Content: text,
+ }
+ if len(toolCalls) > 0 {
+ msg.SetToolCalls(toolCalls)
+ msg.Content = ""
+ }
+
+ out := &dto.OpenAITextResponse{
+ Id: id,
+ Object: "chat.completion",
+ Created: created,
+ Model: resp.Model,
+ Choices: []dto.OpenAITextResponseChoice{
+ {
+ Index: 0,
+ Message: msg,
+ FinishReason: finishReason,
+ },
+ },
+ Usage: *usage,
+ }
+
+ return out, usage, nil
+}
+
+func ExtractOutputTextFromResponses(resp *dto.OpenAIResponsesResponse) string {
+ if resp == nil || len(resp.Output) == 0 {
+ return ""
+ }
+
+ var sb strings.Builder
+
+ // Prefer assistant message outputs.
+ for _, out := range resp.Output {
+ if out.Type != "message" {
+ continue
+ }
+ if out.Role != "" && out.Role != "assistant" {
+ continue
+ }
+ for _, c := range out.Content {
+ if c.Type == "output_text" && c.Text != "" {
+ sb.WriteString(c.Text)
+ }
+ }
+ }
+ if sb.Len() > 0 {
+ return sb.String()
+ }
+ for _, out := range resp.Output {
+ for _, c := range out.Content {
+ if c.Text != "" {
+ sb.WriteString(c.Text)
+ }
+ }
+ }
+ return sb.String()
+}
diff --git a/setting/model_setting/global.go b/setting/model_setting/global.go
index f51ebc89..58017117 100644
--- a/setting/model_setting/global.go
+++ b/setting/model_setting/global.go
@@ -1,14 +1,36 @@
package model_setting
import (
+ "slices"
"strings"
"github.com/QuantumNous/new-api/setting/config"
)
+type ChatCompletionsToResponsesPolicy struct {
+ Enabled bool `json:"enabled"`
+ AllChannels bool `json:"all_channels"`
+ ChannelIDs []int `json:"channel_ids,omitempty"`
+ ModelPatterns []string `json:"model_patterns,omitempty"`
+}
+
+func (p ChatCompletionsToResponsesPolicy) IsChannelEnabled(channelID int) bool {
+ if !p.Enabled {
+ return false
+ }
+ if p.AllChannels {
+ return true
+ }
+ if channelID == 0 || len(p.ChannelIDs) == 0 {
+ return false
+ }
+ return slices.Contains(p.ChannelIDs, channelID)
+}
+
type GlobalSettings struct {
- PassThroughRequestEnabled bool `json:"pass_through_request_enabled"`
- ThinkingModelBlacklist []string `json:"thinking_model_blacklist"`
+ PassThroughRequestEnabled bool `json:"pass_through_request_enabled"`
+ ThinkingModelBlacklist []string `json:"thinking_model_blacklist"`
+ ChatCompletionsToResponsesPolicy ChatCompletionsToResponsesPolicy `json:"chat_completions_to_responses_policy"`
}
// 默认配置
@@ -18,6 +40,10 @@ var defaultOpenaiSettings = GlobalSettings{
"moonshotai/kimi-k2-thinking",
"kimi-k2-thinking",
},
+ ChatCompletionsToResponsesPolicy: ChatCompletionsToResponsesPolicy{
+ Enabled: false,
+ AllChannels: true,
+ },
}
// 全局实例
diff --git a/setting/model_setting/qwen.go b/setting/model_setting/qwen.go
new file mode 100644
index 00000000..ccab5759
--- /dev/null
+++ b/setting/model_setting/qwen.go
@@ -0,0 +1,50 @@
+package model_setting
+
+import (
+ "strings"
+
+ "github.com/QuantumNous/new-api/setting/config"
+)
+
+// QwenSettings defines Qwen model configuration. 注意bool要以enabled结尾才可以生效编辑
+type QwenSettings struct {
+ SyncImageModels []string `json:"sync_image_models"`
+}
+
+// 默认配置
+var defaultQwenSettings = QwenSettings{
+ SyncImageModels: []string{
+ "z-image",
+ "qwen-image",
+ "wan2.6",
+ "qwen-image-edit",
+ "qwen-image-edit-max",
+ "qwen-image-edit-max-2026-01-16",
+ "qwen-image-edit-plus",
+ "qwen-image-edit-plus-2025-12-15",
+ "qwen-image-edit-plus-2025-10-30",
+ },
+}
+
+// 全局实例
+var qwenSettings = defaultQwenSettings
+
+func init() {
+ // 注册到全局配置管理器
+ config.GlobalConfig.Register("qwen", &qwenSettings)
+}
+
+// GetQwenSettings
+func GetQwenSettings() *QwenSettings {
+ return &qwenSettings
+}
+
+// IsSyncImageModel
+func IsSyncImageModel(model string) bool {
+ for _, m := range qwenSettings.SyncImageModels {
+ if strings.Contains(model, m) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/setting/operation_setting/status_code_ranges.go b/setting/operation_setting/status_code_ranges.go
new file mode 100644
index 00000000..698c87c9
--- /dev/null
+++ b/setting/operation_setting/status_code_ranges.go
@@ -0,0 +1,184 @@
+package operation_setting
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+type StatusCodeRange struct {
+ Start int
+ End int
+}
+
+var AutomaticDisableStatusCodeRanges = []StatusCodeRange{{Start: 401, End: 401}}
+
+// Default behavior matches legacy hardcoded retry rules in controller/relay.go shouldRetry:
+// retry for 1xx, 3xx, 4xx(except 400/408), 5xx(except 504/524), and no retry for 2xx.
+var AutomaticRetryStatusCodeRanges = []StatusCodeRange{
+ {Start: 100, End: 199},
+ {Start: 300, End: 399},
+ {Start: 401, End: 407},
+ {Start: 409, End: 499},
+ {Start: 500, End: 503},
+ {Start: 505, End: 523},
+ {Start: 525, End: 599},
+}
+
+func AutomaticDisableStatusCodesToString() string {
+ return statusCodeRangesToString(AutomaticDisableStatusCodeRanges)
+}
+
+func AutomaticDisableStatusCodesFromString(s string) error {
+ ranges, err := ParseHTTPStatusCodeRanges(s)
+ if err != nil {
+ return err
+ }
+ AutomaticDisableStatusCodeRanges = ranges
+ return nil
+}
+
+func ShouldDisableByStatusCode(code int) bool {
+ return shouldMatchStatusCodeRanges(AutomaticDisableStatusCodeRanges, code)
+}
+
+func AutomaticRetryStatusCodesToString() string {
+ return statusCodeRangesToString(AutomaticRetryStatusCodeRanges)
+}
+
+func AutomaticRetryStatusCodesFromString(s string) error {
+ ranges, err := ParseHTTPStatusCodeRanges(s)
+ if err != nil {
+ return err
+ }
+ AutomaticRetryStatusCodeRanges = ranges
+ return nil
+}
+
+func ShouldRetryByStatusCode(code int) bool {
+ return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code)
+}
+
+func statusCodeRangesToString(ranges []StatusCodeRange) string {
+ if len(ranges) == 0 {
+ return ""
+ }
+ parts := make([]string, 0, len(ranges))
+ for _, r := range ranges {
+ if r.Start == r.End {
+ parts = append(parts, strconv.Itoa(r.Start))
+ continue
+ }
+ parts = append(parts, fmt.Sprintf("%d-%d", r.Start, r.End))
+ }
+ return strings.Join(parts, ",")
+}
+
+func shouldMatchStatusCodeRanges(ranges []StatusCodeRange, code int) bool {
+ if code < 100 || code > 599 {
+ return false
+ }
+ for _, r := range ranges {
+ if code < r.Start {
+ return false
+ }
+ if code <= r.End {
+ return true
+ }
+ }
+ return false
+}
+
+func ParseHTTPStatusCodeRanges(input string) ([]StatusCodeRange, error) {
+ input = strings.TrimSpace(input)
+ if input == "" {
+ return nil, nil
+ }
+
+ input = strings.NewReplacer(",", ",").Replace(input)
+ segments := strings.Split(input, ",")
+
+ var ranges []StatusCodeRange
+ var invalid []string
+
+ for _, seg := range segments {
+ seg = strings.TrimSpace(seg)
+ if seg == "" {
+ continue
+ }
+ r, err := parseHTTPStatusCodeToken(seg)
+ if err != nil {
+ invalid = append(invalid, seg)
+ continue
+ }
+ ranges = append(ranges, r)
+ }
+
+ if len(invalid) > 0 {
+ return nil, fmt.Errorf("invalid http status code rules: %s", strings.Join(invalid, ", "))
+ }
+ if len(ranges) == 0 {
+ return nil, nil
+ }
+
+ sort.Slice(ranges, func(i, j int) bool {
+ if ranges[i].Start == ranges[j].Start {
+ return ranges[i].End < ranges[j].End
+ }
+ return ranges[i].Start < ranges[j].Start
+ })
+
+ merged := []StatusCodeRange{ranges[0]}
+ for _, r := range ranges[1:] {
+ last := &merged[len(merged)-1]
+ if r.Start <= last.End+1 {
+ if r.End > last.End {
+ last.End = r.End
+ }
+ continue
+ }
+ merged = append(merged, r)
+ }
+
+ return merged, nil
+}
+
+func parseHTTPStatusCodeToken(token string) (StatusCodeRange, error) {
+ token = strings.TrimSpace(token)
+ token = strings.ReplaceAll(token, " ", "")
+ if token == "" {
+ return StatusCodeRange{}, fmt.Errorf("empty token")
+ }
+
+ if strings.Contains(token, "-") {
+ parts := strings.Split(token, "-")
+ if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ return StatusCodeRange{}, fmt.Errorf("invalid range token: %s", token)
+ }
+ start, err := strconv.Atoi(parts[0])
+ if err != nil {
+ return StatusCodeRange{}, fmt.Errorf("invalid range start: %s", token)
+ }
+ end, err := strconv.Atoi(parts[1])
+ if err != nil {
+ return StatusCodeRange{}, fmt.Errorf("invalid range end: %s", token)
+ }
+ if start > end {
+ return StatusCodeRange{}, fmt.Errorf("range start > end: %s", token)
+ }
+ if start < 100 || end > 599 {
+ return StatusCodeRange{}, fmt.Errorf("range out of bounds: %s", token)
+ }
+ return StatusCodeRange{Start: start, End: end}, nil
+ }
+
+ code, err := strconv.Atoi(token)
+ if err != nil {
+ return StatusCodeRange{}, fmt.Errorf("invalid status code: %s", token)
+ }
+ if code < 100 || code > 599 {
+ return StatusCodeRange{}, fmt.Errorf("status code out of bounds: %s", token)
+ }
+ return StatusCodeRange{Start: code, End: code}, nil
+}
diff --git a/setting/operation_setting/status_code_ranges_test.go b/setting/operation_setting/status_code_ranges_test.go
new file mode 100644
index 00000000..5801824a
--- /dev/null
+++ b/setting/operation_setting/status_code_ranges_test.go
@@ -0,0 +1,79 @@
+package operation_setting
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseHTTPStatusCodeRanges_CommaSeparated(t *testing.T) {
+ ranges, err := ParseHTTPStatusCodeRanges("401,403,500-599")
+ require.NoError(t, err)
+ require.Equal(t, []StatusCodeRange{
+ {Start: 401, End: 401},
+ {Start: 403, End: 403},
+ {Start: 500, End: 599},
+ }, ranges)
+}
+
+func TestParseHTTPStatusCodeRanges_MergeAndNormalize(t *testing.T) {
+ ranges, err := ParseHTTPStatusCodeRanges("500-505,504,401,403,402")
+ require.NoError(t, err)
+ require.Equal(t, []StatusCodeRange{
+ {Start: 401, End: 403},
+ {Start: 500, End: 505},
+ }, ranges)
+}
+
+func TestParseHTTPStatusCodeRanges_Invalid(t *testing.T) {
+ _, err := ParseHTTPStatusCodeRanges("99,600,foo,500-400,500-")
+ require.Error(t, err)
+}
+
+func TestParseHTTPStatusCodeRanges_NoComma_IsInvalid(t *testing.T) {
+ _, err := ParseHTTPStatusCodeRanges("401 403")
+ require.Error(t, err)
+}
+
+func TestShouldDisableByStatusCode(t *testing.T) {
+ orig := AutomaticDisableStatusCodeRanges
+ t.Cleanup(func() { AutomaticDisableStatusCodeRanges = orig })
+
+ AutomaticDisableStatusCodeRanges = []StatusCodeRange{
+ {Start: 401, End: 403},
+ {Start: 500, End: 599},
+ }
+
+ require.True(t, ShouldDisableByStatusCode(401))
+ require.True(t, ShouldDisableByStatusCode(403))
+ require.False(t, ShouldDisableByStatusCode(404))
+ require.True(t, ShouldDisableByStatusCode(500))
+ require.False(t, ShouldDisableByStatusCode(200))
+}
+
+func TestShouldRetryByStatusCode(t *testing.T) {
+ orig := AutomaticRetryStatusCodeRanges
+ t.Cleanup(func() { AutomaticRetryStatusCodeRanges = orig })
+
+ AutomaticRetryStatusCodeRanges = []StatusCodeRange{
+ {Start: 429, End: 429},
+ {Start: 500, End: 599},
+ }
+
+ require.True(t, ShouldRetryByStatusCode(429))
+ require.True(t, ShouldRetryByStatusCode(500))
+ require.False(t, ShouldRetryByStatusCode(400))
+ require.False(t, ShouldRetryByStatusCode(200))
+}
+
+func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) {
+ require.False(t, ShouldRetryByStatusCode(200))
+ require.False(t, ShouldRetryByStatusCode(400))
+ require.True(t, ShouldRetryByStatusCode(401))
+ require.False(t, ShouldRetryByStatusCode(408))
+ require.True(t, ShouldRetryByStatusCode(429))
+ require.True(t, ShouldRetryByStatusCode(500))
+ require.False(t, ShouldRetryByStatusCode(504))
+ require.False(t, ShouldRetryByStatusCode(524))
+ require.True(t, ShouldRetryByStatusCode(599))
+}
diff --git a/types/error.go b/types/error.go
index b060a9db..e112eeef 100644
--- a/types/error.go
+++ b/types/error.go
@@ -130,6 +130,20 @@ func (e *NewAPIError) Error() string {
return e.Err.Error()
}
+func (e *NewAPIError) ErrorWithStatusCode() string {
+ if e == nil {
+ return ""
+ }
+ msg := e.Error()
+ if e.StatusCode == 0 {
+ return msg
+ }
+ if msg == "" {
+ return fmt.Sprintf("status_code=%d", e.StatusCode)
+ }
+ return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
+}
+
func (e *NewAPIError) MaskSensitiveError() string {
if e == nil {
return ""
@@ -144,6 +158,20 @@ func (e *NewAPIError) MaskSensitiveError() string {
return common.MaskSensitiveInfo(errStr)
}
+func (e *NewAPIError) MaskSensitiveErrorWithStatusCode() string {
+ if e == nil {
+ return ""
+ }
+ msg := e.MaskSensitiveError()
+ if e.StatusCode == 0 {
+ return msg
+ }
+ if msg == "" {
+ return fmt.Sprintf("status_code=%d", e.StatusCode)
+ }
+ return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
+}
+
func (e *NewAPIError) SetMessage(message string) {
e.Err = errors.New(message)
}
diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx
index d87fc349..5111f1f6 100644
--- a/web/src/components/auth/LoginForm.jsx
+++ b/web/src/components/auth/LoginForm.jsx
@@ -17,9 +17,10 @@ along with this program. If not, see
+ {rawText}
+
+