diff --git a/README.en.md b/README.en.md index c3be8381..10a3cdb0 100644 --- a/README.en.md +++ b/README.en.md @@ -1,10 +1,13 @@ +

+ 中文 | English +

![new-api](/web/public/logo.png)

# New API

-🍥 Next Generation LLM Gateway and AI Asset Management System
+🍥 Next-Generation Large Model Gateway and AI Asset Management System

Calcium-Ion%2Fnew-api | Trendshift

@@ -33,171 +36,159 @@

> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)

> [!IMPORTANT]
-> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and relevant laws and regulations. Not to be used for illegal purposes.
-> - This project is for personal learning only. Stability is not guaranteed, and no technical support is provided.
+> - This project is for personal learning purposes only, with no guarantee of stability or technical support.
+> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
+> - In accordance with the [Interim Measures for the Management of Generative Artificial Intelligence Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI services to the public in China.
+
+## 📚 Documentation
+
+For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
+
+You can also consult the AI-generated DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)

## ✨ Key Features

-1. 🎨 New UI interface (some interfaces pending update)
-2. 🌍 Multi-language support (work in progress)
-3. 🎨 Added [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface support, [Integration Guide](Midjourney.md)
-4. 💰 Online recharge support, configurable in system settings:
-    - [x] EasyPay
-5. 🔍 Query usage quota by key:
-    - Works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
-6. 📑 Configurable items per page in pagination
-7. 🔄 Compatible with original One API database (one-api.db)
-8. 💵 Support per-request model pricing, configurable in System Settings - Operation Settings
-9. ⚖️ Support channel **weighted random** selection
-10. 📈 Data dashboard (console)
-11. 🔒 Configurable model access per token
-12. 🤖 Telegram authorization login support:
-    1. System Settings - Configure Login Registration - Allow Telegram Login
-    2. Send /setdomain command to [@Botfather](https://t.me/botfather)
-    3. Select your bot, then enter http(s)://your-website/login
-    4. Telegram Bot name is the bot username without @
-13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
-14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md)
-15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels
-16. 🧠 Support for setting reasoning effort through model name suffix:
-    - Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
-    - Add suffix `-medium` to set medium reasoning effort
-    - Add suffix `-low` to set low reasoning effort
-17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `` tags and concatenated to the content returned.
-18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
-19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits:
-    1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
-    2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits)
+New API offers a wide range of features; see the [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
+
+1. 🎨 Brand-new UI
+2. 🌍 Multi-language support
+3. 💰 Online recharge functionality (YiPay)
+4. 🔍 Support for querying usage quotas by key (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
+5. 🔄 Compatible with the original One API database
+6. 💵 Support for pay-per-use model pricing
+7. ⚖️ Support for weighted random channel selection
+8. 📈 Data dashboard (console)
+9. 🔒 Token grouping and model restrictions
+10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC)
+11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
+12. ⚡ Support for the OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime)
+13. ⚡ Support for the Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
+14. Support for entering the chat interface via the /chat2link route
+15. 🧠 Support for setting reasoning effort through model name suffixes (see the request example after this list):
+    1. OpenAI o-series models
+        - Add the `-high` suffix for high reasoning effort (e.g., `o3-mini-high`)
+        - Add the `-medium` suffix for medium reasoning effort (e.g., `o3-mini-medium`)
+        - Add the `-low` suffix for low reasoning effort (e.g., `o3-mini-low`)
+    2. Claude thinking models
+        - Add the `-thinking` suffix to enable thinking mode (e.g., `claude-3-7-sonnet-20250219-thinking`)
+16. 🔄 Thinking-to-content functionality
+17. 🔄 Model rate limiting for users
+18. 💰 Cache billing support, which bills cache hits at a configurable ratio when enabled:
+    1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings`
+    2. Set `Prompt Cache Ratio` on the channel, range 0-1; e.g., 0.5 means cache hits are billed at 50%
    3. Supported channels:
        - [x] OpenAI
-        - [x] Azure
+        - [x] Azure
        - [x] DeepSeek
-        - [ ] Claude
+        - [x] Claude
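+
+Because reasoning effort is selected purely through the model name, any OpenAI-compatible client works unchanged. A minimal sketch (the base URL and token below are placeholders for your own deployment):
+
+```shell
+curl http://localhost:3000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxx" \
+  -d '{"model": "o3-mini-high", "messages": [{"role": "user", "content": "Hello"}]}'
+```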

## Model Support

-This version additionally supports:
-1. Third-party model **gpts** (gpt-4-gizmo-*)
-2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
-3. Custom channels with full API URL support
-4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
-5. Rerank models, supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [Integration Guide](Rerank.md)
-6. Dify
-You can add custom models gpt-4-gizmo-* in channels. These are third-party models and cannot be called with official OpenAI keys.
+This version supports multiple models; see the [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details:

-## Additional Configurations Beyond One API
-- `GENERATE_DEFAULT_TOKEN`: Generate initial token for new users, default `false`
-- `STREAMING_TIMEOUT`: Set streaming response timeout, default 60 seconds
-- `DIFY_DEBUG`: Output workflow and node info to client for Dify channel, default `true`
-- `FORCE_STREAM_OPTION`: Override client stream_options parameter, default `true`
-- `GET_MEDIA_TOKEN`: Calculate image tokens, default `true`
-- `GET_MEDIA_TOKEN_NOT_STREAM`: Calculate image tokens in non-stream mode, default `true`
-- `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
-- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
-- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
-- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
-- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
-- `CRYPTO_SECRET`: Encryption key for encrypting database content
-- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
-- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
-- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
+1. Third-party models **gpts** (gpt-4-gizmo-*)
+2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image)
+3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music)
+4. Custom channels, with support for entering the full call address
+5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank) (request example after this list)
+6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
+7. Dify (currently only chatflow is supported)
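+
+As a sketch, a Rerank request against your own deployment might look like the following, assuming the Jina-compatible `/v1/rerank` route described in the Rerank API documentation (the model name and token are placeholders):
+
+```shell
+curl http://localhost:3000/v1/rerank \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxx" \
+  -d '{"model": "jina-reranker-v2-base-multilingual", "query": "What is New API?", "documents": ["New API is a large model gateway.", "The weather is nice today."]}'
+```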
+
+## Environment Variable Configuration
+
+For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
+
+- `GENERATE_DEFAULT_TOKEN`: Whether to generate an initial token for newly registered users, default is `false`
+- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
+- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
+- `FORCE_STREAM_OPTION`: Whether to override the client's stream_options parameter, default is `true`
+- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
+- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true`
+- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true`
+- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE`
+- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
+- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
+- `CRYPTO_SECRET`: Encryption key used for encrypting database content
+- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
+- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
+- `ERROR_LOG_ENABLED`: Whether to record and display error logs, default is `false`
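+
+These are ordinary environment variables, so they can be supplied however your deployment starts the process; for example, with Docker (the values below are illustrative):
+
+```shell
+docker run --name new-api -d --restart always -p 3000:3000 \
+  -e TZ=Asia/Shanghai \
+  -e STREAMING_TIMEOUT=120 \
+  -e ERROR_LOG_ENABLED=true \
+  -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
+```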

## Deployment

+For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation):
+
> [!TIP]
-> Latest Docker image: `calciumion/new-api:latest`
-> Default account: root, password: 123456
+> Latest Docker image: `calciumion/new-api:latest`

-### Multi-Server Deployment
-- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
-- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
+### Multi-machine Deployment Considerations
+- The `SESSION_SECRET` environment variable must be set; otherwise, login status will be inconsistent across machines
+- If Redis is shared, `CRYPTO_SECRET` must be set; otherwise, Redis content cannot be accessed across machines

-### Requirements
-- Local database (default): SQLite (Docker deployment must mount `/data` directory)
-- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
+### Deployment Requirements
+- Local database (default): SQLite (Docker deployments must mount the `/data` directory)
+- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6

-### Deployment with BT Panel
-Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
-After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
-After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
-[Pictorial Guide](BT.md)
+### Deployment Methods

-### Docker Deployment
+#### Using BaoTa Panel Docker Feature
+Install BaoTa Panel (version **9.2.0** or above), then find **New-API** in the application store and install it.
+[Tutorial with images](./docs/installation/BT.md)

-### Using Docker Compose (Recommended)
+#### Using Docker Compose (Recommended)
```shell
-# Clone project
+# Download the project
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# Edit docker-compose.yml as needed
-# nano docker-compose.yml
-# vim docker-compose.yml
# Start
docker-compose up -d
```

-#### Update Version
+#### Using Docker Image Directly
```shell
-docker-compose pull
-docker-compose up -d
-```
-
-### Direct Docker Image Usage
-```shell
-# SQLite deployment:
+# Using SQLite
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-# MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
-# Example:
+# Using MySQL
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```

-#### Update Version
-```shell
-# Pull the latest image
-docker pull calciumion/new-api:latest
-# Stop and remove the old container
-docker stop new-api
-docker rm new-api
-# Run the new container with the same parameters as before
-docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-```
+## Channel Retry and Cache
+Channel retry functionality has been implemented; you can set the number of retries in `Settings->Operation Settings->General Settings`. Enabling the cache is **recommended**.

-Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
-```shell
-docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
-```
+### Cache Configuration Method
+1. `REDIS_CONN_STRING`: Use Redis as the cache
+2. `MEMORY_CACHE_ENABLED`: Enable the in-memory cache (no need to set this manually if Redis is configured)
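+
+For example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`, or `MEMORY_CACHE_ENABLED=true` when running without Redis.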
-## Channel Retry
-Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
-If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.

+## API Documentation

-### Cache Configuration
-1. `REDIS_CONN_STRING`: Use Redis as cache
-   + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
-2. `MEMORY_CACHE_ENABLED`: Enable memory cache, default `false`
-   + Example: `MEMORY_CACHE_ENABLED=true`

+For detailed API documentation, please refer to the [API Documentation](https://docs.newapi.pro/api):

-### Why Some Errors Don't Retry
-Error codes 400, 504, 524 won't retry
-### To Enable Retry for 400
-In `Channel->Edit`, set `Status Code Override` to:
-```json
-{
-  "400": "500"
-}
-```
-
-## Integration Guides
-- [Midjourney Integration](Midjourney.md)
-- [Suno Integration](Suno.md)
+- [Chat API](https://docs.newapi.pro/api/openai-chat)
+- [Image API](https://docs.newapi.pro/api/openai-image)
+- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
+- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
+- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)

## Related Projects
- [One API](https://github.com/songquanpeng/one-api): Original project
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
-- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-gen AI B/C solution
-- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key
+- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation one-stop AI solution for B-end and C-end users
+- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key
+
+Other projects based on New API:
+- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API
+- [VoAPI](https://github.com/VoAPI/VoAPI): Version of New API with a restyled frontend
+
+## Help and Support
+
+If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support):
+- [Community Interaction](https://docs.newapi.pro/support/community-interaction)
+- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
+- [FAQ](https://docs.newapi.pro/support/faq)

## 🌟 Star History

-[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)
\ No newline at end of file
+[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)
diff --git a/README.md b/README.md
index 6ac8839b..e9d1c154 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,9 @@
 
 详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
 
+也可访问AI生成的DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
 ## ✨ 主要特性
 
 New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
@@ -107,9 +110,10 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
 - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
 - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
 - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
-- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview`
+- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
+- `ERROR_LOG_ENABLED`: 是否记录并显示错误日志,默认`false`
 
 ## 部署
 
@@ -130,7 +134,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
 
 #### 使用宝塔面板Docker功能部署
 安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
-[图文教程](BT.md)
+[图文教程](./docs/installation/BT.md)
 
 #### 使用Docker Compose部署(推荐)
 ```shell
diff --git a/common/constants.go b/common/constants.go
index dd4f3b04..bee00506 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -240,6 +240,7 @@
ChannelTypeBaiduV2 = 46 ChannelTypeXinference = 47 ChannelTypeXai = 48 + ChannelTypeCoze = 49 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{ "https://qianfan.baidubce.com", //46 "", //47 "https://api.x.ai", //48 + "https://api.coze.cn", //49 } diff --git a/constant/azure.go b/constant/azure.go new file mode 100644 index 00000000..d84040ce --- /dev/null +++ b/constant/azure.go @@ -0,0 +1,5 @@ +package constant + +import "time" + +var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix() diff --git a/constant/env.go b/constant/env.go index fae48625..612f3e8b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -31,7 +31,7 @@ func InitEnv() { GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") + AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 41f8d8f7..2bda0fd2 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -108,6 +108,13 @@ type DeepSeekUsageResponse struct { } `json:"balance_infos"` } +type OpenRouterCreditResponse struct { + Data struct { + TotalCredits float64 `json:"total_credits"` + TotalUsage float64 `json:"total_usage"` + } `json:"data"` +} + // GetAuthHeader get auth header func GetAuthHeader(token string) http.Header { h := http.Header{} @@ -281,6 +288,22 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { return response.TotalAvailable, nil } +func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) { + url := "https://openrouter.ai/api/v1/credits" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := OpenRouterCreditResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + balance := response.Data.TotalCredits - response.Data.TotalUsage + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { @@ -307,6 +330,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return updateChannelSiliconFlowBalance(channel) case common.ChannelTypeDeepSeek: return updateChannelDeepSeekBalance(channel) + case common.ChannelTypeOpenRouter: + return updateChannelOpenRouterBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/channel.go b/controller/channel.go index ad85fe24..a31e1f47 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -119,8 +119,11 @@ func FetchUpstreamModels(c *gin.Context) { baseURL = channel.GetBaseURL() } url := fmt.Sprintf("%s/v1/models", baseURL) - if channel.Type == common.ChannelTypeGemini { + switch channel.Type { + case common.ChannelTypeGemini: url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) + case 
common.ChannelTypeAli: + url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) } body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { diff --git a/controller/option.go b/controller/option.go index 81ef463c..250f16bb 100644 --- a/controller/option.go +++ b/controller/option.go @@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "ModelRequestRateLimitGroup": + err = setting.CheckModelRequestRateLimitGroup(option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } } err = model.UpdateOption(option.Key, option.Value) diff --git a/controller/relay.go b/controller/relay.go index 91477665..1a875dbc 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,8 +4,6 @@ import ( "bytes" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "log" "net/http" @@ -20,6 +18,9 @@ import ( "one-api/relay/helper" "one-api/service" "strings" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { @@ -37,6 +38,10 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode err = relay.RerankHelper(c, relayMode) case relayconstant.RelayModeEmbeddings: err = relay.EmbeddingHelper(c) + case relayconstant.RelayModeResponses: + err = relay.ResponsesHelper(c) + case relayconstant.RelayModeGemini: + err = relay.GeminiHelper(c) default: err = relay.TextHelper(c) } diff --git a/controller/user.go b/controller/user.go index e194f531..fd53e743 100644 --- a/controller/user.go +++ b/controller/user.go @@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) { user.Password = "" // rollback to what it should be cleanUser.Password = "" } - updatePassword := user.Password != "" + updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } if err := cleanUser.Update(updatePassword); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) { return } +func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) { + var currentUser *model.User + currentUser, err = model.GetUserById(userId, true) + if err != nil { + return + } + if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) { + err = fmt.Errorf("原密码错误") + return + } + if newPassword == "" { + return + } + updatePassword = true + return +} + func DeleteUser(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { diff --git a/BT.md b/docs/installation/BT.md similarity index 98% rename from BT.md rename to docs/installation/BT.md index e57cdab7..b4ea5b2f 100644 --- a/BT.md +++ b/docs/installation/BT.md @@ -1,3 +1,3 @@ -密钥为环境变量SESSION_SECRET - -![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0) +密钥为环境变量SESSION_SECRET + +![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0) diff --git a/Midjourney.md b/docs/models/Midjourney.md similarity index 100% rename from Midjourney.md rename to docs/models/Midjourney.md diff --git a/Rerank.md b/docs/models/Rerank.md similarity index 100% rename from Rerank.md rename to docs/models/Rerank.md diff --git a/Suno.md b/docs/models/Suno.md similarity 
index 100% rename from Suno.md rename to docs/models/Suno.md diff --git a/dto/claude.go b/dto/claude.go index 8068feb8..36dfc02e 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -7,17 +7,18 @@ type ClaudeMetadata struct { } type ClaudeMediaMessage struct { - Type string `json:"type,omitempty"` - Text *string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - PartialJson *string `json:"partial_json,omitempty"` - Role string `json:"role,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` - Delta string `json:"delta,omitempty"` + Type string `json:"type,omitempty"` + Text *string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson *string `json:"partial_json,omitempty"` + Role string `json:"role,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + Delta string `json:"delta,omitempty"` + CacheControl json.RawMessage `json:"cache_control,omitempty"` // tool_calls Id string `json:"id,omitempty"` Name string `json:"name,omitempty"` diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..a1309b6c 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,6 +12,9 @@ type ImageRequest struct { Style string `json:"style,omitempty"` User string `json:"user,omitempty"` ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } type ImageResponse struct { diff --git a/dto/openai_request.go b/dto/openai_request.go index 652d8cce..a7325fe8 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -2,6 +2,7 @@ package dto import ( "encoding/json" + "one-api/common" "strings" ) @@ -28,7 +29,6 @@ type GeneralOpenAIRequest struct { MaxTokens uint `json:"max_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"` - //Reasoning json.RawMessage `json:"reasoning,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` @@ -43,6 +43,7 @@ type GeneralOpenAIRequest struct { ResponseFormat *ResponseFormat `json:"response_format,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"` Seed float64 `json:"seed,omitempty"` + ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` @@ -53,6 +54,16 @@ type GeneralOpenAIRequest struct { Audio any `json:"audio,omitempty"` EnableThinking any `json:"enable_thinking,omitempty"` // ali ExtraBody any `json:"extra_body,omitempty"` + WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + // OpenRouter Params + Reasoning json.RawMessage `json:"reasoning,omitempty"` +} + +func (r *GeneralOpenAIRequest) ToMap() map[string]any { + result := make(map[string]any) + data, _ := common.EncodeJson(r) + _ = common.DecodeJson(data, &result) + return result } type ToolCallRequest struct { @@ -72,11 +83,11 @@ type StreamOptions struct { IncludeUsage bool 
`json:"include_usage,omitempty"` } -func (r GeneralOpenAIRequest) GetMaxTokens() int { +func (r *GeneralOpenAIRequest) GetMaxTokens() int { return int(r.MaxTokens) } -func (r GeneralOpenAIRequest) ParseInput() []string { +func (r *GeneralOpenAIRequest) ParseInput() []string { if r.Input == nil { return nil } @@ -114,6 +125,9 @@ type MediaContent struct { ImageUrl any `json:"image_url,omitempty"` InputAudio any `json:"input_audio,omitempty"` File any `json:"file,omitempty"` + VideoUrl any `json:"video_url,omitempty"` + // OpenRouter Params + CacheControl json.RawMessage `json:"cache_control,omitempty"` } func (m *MediaContent) GetImageMedia() *MessageImageUrl { @@ -158,11 +172,16 @@ type MessageFile struct { FileId string `json:"file_id,omitempty"` } +type MessageVideoUrl struct { + Url string `json:"url"` +} + const ( ContentTypeText = "text" ContentTypeImageURL = "image_url" ContentTypeInputAudio = "input_audio" ContentTypeFile = "file" + ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 ) func (m *Message) GetPrefix() bool { @@ -346,6 +365,15 @@ func (m *Message) ParseContent() []MediaContent { } } } + case ContentTypeVideoUrl: + if videoUrl, ok := contentItem["video_url"].(string); ok { + contentList = append(contentList, MediaContent{ + Type: ContentTypeVideoUrl, + VideoUrl: &MessageVideoUrl{ + Url: videoUrl, + }, + }) + } } } } @@ -355,3 +383,54 @@ func (m *Message) ParseContent() []MediaContent { } return contentList } + +type WebSearchOptions struct { + SearchContextSize string `json:"search_context_size,omitempty"` + UserLocation json.RawMessage `json:"user_location,omitempty"` +} + +type OpenAIResponsesRequest struct { + Model string `json:"model"` + Input json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + Instructions json.RawMessage `json:"instructions,omitempty"` + MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Store bool `json:"store,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools []ResponsesToolsCall `json:"tools,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` +} + +type Reasoning struct { + Effort string `json:"effort,omitempty"` + Summary string `json:"summary,omitempty"` +} + +type ResponsesToolsCall struct { + Type string `json:"type"` + // Web Search + UserLocation json.RawMessage `json:"user_location,omitempty"` + SearchContextSize string `json:"search_context_size,omitempty"` + // File Search + VectorStoreIds []string `json:"vector_store_ids,omitempty"` + MaxNumResults uint `json:"max_num_results,omitempty"` + Filters json.RawMessage `json:"filters,omitempty"` + // Computer Use + DisplayWidth uint `json:"display_width,omitempty"` + DisplayHeight uint `json:"display_height,omitempty"` + Environment string `json:"environment,omitempty"` + // Function + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} diff --git a/dto/openai_response.go b/dto/openai_response.go index 
c2100ec8..790d4df8 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -1,5 +1,7 @@ package dto +import "encoding/json" + type SimpleResponse struct { Usage `json:"usage"` Error *OpenAIError `json:"error"` @@ -191,3 +193,68 @@ type OutputTokenDetails struct { AudioTokens int `json:"audio_tokens"` ReasoningTokens int `json:"reasoning_tokens"` } + +type OpenAIResponsesResponse struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + Status string `json:"status"` + Error *OpenAIError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + Instructions string `json:"instructions"` + MaxOutputTokens int `json:"max_output_tokens"` + Model string `json:"model"` + Output []ResponsesOutput `json:"output"` + ParallelToolCalls bool `json:"parallel_tool_calls"` + PreviousResponseID string `json:"previous_response_id"` + Reasoning *Reasoning `json:"reasoning"` + Store bool `json:"store"` + Temperature float64 `json:"temperature"` + ToolChoice string `json:"tool_choice"` + Tools []ResponsesToolsCall `json:"tools"` + TopP float64 `json:"top_p"` + Truncation string `json:"truncation"` + Usage *Usage `json:"usage"` + User json.RawMessage `json:"user"` + Metadata json.RawMessage `json:"metadata"` +} + +type IncompleteDetails struct { + Reasoning string `json:"reasoning"` +} + +type ResponsesOutput struct { + Type string `json:"type"` + ID string `json:"id"` + Status string `json:"status"` + Role string `json:"role"` + Content []ResponsesOutputContent `json:"content"` +} + +type ResponsesOutputContent struct { + Type string `json:"type"` + Text string `json:"text"` + Annotations []interface{} `json:"annotations"` +} + +const ( + BuildInToolWebSearchPreview = "web_search_preview" + BuildInToolFileSearch = "file_search" +) + +const ( + BuildInCallWebSearchCall = "web_search_call" +) + +const ( + ResponsesOutputTypeItemAdded = "response.output_item.added" + ResponsesOutputTypeItemDone = "response.output_item.done" +) + +// ResponsesStreamResponse 用于处理 /v1/responses 流式响应 +type ResponsesStreamResponse struct { + Type string `json:"type"` + Response *OpenAIResponsesResponse `json:"response,omitempty"` + Delta string `json:"delta,omitempty"` + Item *ResponsesOutput `json:"item,omitempty"` +} diff --git a/main.go b/main.go index 4bdc97bd..c286650f 100644 --- a/main.go +++ b/main.go @@ -80,6 +80,8 @@ func main() { // Initialize options model.InitOptionMap() + service.InitTokenEncoders() + if common.RedisEnabled { // for compatibility with old versions common.MemoryCacheEnabled = true @@ -87,9 +89,22 @@ func main() { if common.MemoryCacheEnabled { common.SysLog("memory cache enabled") common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) - model.InitChannelCache() - } - if common.MemoryCacheEnabled { + + // Add panic recovery and retry for InitChannelCache + func() { + defer func() { + if r := recover(); r != nil { + common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + // Retry once + _, fixErr := model.FixAbility() + if fixErr != nil { + common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + } + } + }() + model.InitChannelCache() + }() + go model.SyncOptions(common.SyncFrequency) go model.SyncChannelCache(common.SyncFrequency) } @@ -133,8 +148,6 @@ func main() { common.SysLog("pprof enabled") } - service.InitTokenEncoders() - // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { 
diff --git a/middleware/auth.go b/middleware/auth.go
index fece4553..ce86bb36 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,13 +1,14 @@
 package middleware
 
 import (
-	"github.com/gin-contrib/sessions"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"strings"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
 )
 
 func validUserInfo(username string, role int) bool {
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
 			c.Request.Header.Set("Authorization", "Bearer "+key)
 		}
 	}
+	// Gemini API: the key may also arrive in the query string
+	if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+		skKey := c.Query("key")
+		if skKey != "" {
+			c.Request.Header.Set("Authorization", "Bearer "+skKey)
+		}
+	}
 	key := c.Request.Header.Get("Authorization")
 	parts := make([]string, 0)
 	key = strings.TrimPrefix(key, "Bearer ")
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 51fd8fd1..1bfe1821 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+		// Gemini API path handling: /v1beta/models/gemini-2.0-flash:generateContent
+		relayMode := relayconstant.RelayModeGemini
+		modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+		if modelName != "" {
+			modelRequest.Model = modelName
+		}
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}
@@ -185,7 +193,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
-		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1")
+		modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 		relayMode := relayconstant.RelayModeAudioSpeech
@@ -213,6 +221,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_name", channel.Name)
 	c.Set("channel_type", channel.Type)
+	c.Set("channel_create_time", channel.CreatedTime)
 	c.Set("channel_setting", channel.GetSetting())
 	c.Set("param_override", channel.GetParamOverride())
 	if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
@@ -239,5 +248,35 @@
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeMokaAI:
 		c.Set("api_version", channel.Other)
+	case common.ChannelTypeCoze:
+		c.Set("bot_id", channel.Other)
 	}
 }
+
+// extractModelNameFromGeminiPath extracts the model name from a Gemini API URL path
+// input format: /v1beta/models/gemini-2.0-flash:generateContent
+// output: gemini-2.0-flash
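+// For example (illustrative):
+//
+//	"/v1beta/models/gemini-2.0-flash:generateContent" -> "gemini-2.0-flash"
+//	"/v1beta/models/gemini-2.0-flash"                 -> "gemini-2.0-flash"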
+func extractModelNameFromGeminiPath(path string) string {
+	// locate "/models/"
+	modelsPrefix := "/models/"
+	modelsIndex := strings.Index(path, modelsPrefix)
+	if modelsIndex == -1 {
+		return ""
+	}
+
+	// start extracting right after "/models/"
+	startIndex := modelsIndex + len(modelsPrefix)
+	if startIndex >= len(path) {
+		return ""
+	}
+
+	// locate ":"; the model name comes before it
+	colonIndex := strings.Index(path[startIndex:], ":")
+	if colonIndex == -1 {
+		// no ":" found, return everything after "/models/"
+		return path[startIndex:]
+	}
+
+	// return the model name portion
+	return path[startIndex : startIndex+colonIndex]
+}
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index 581dc451..34caa59b 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/common/limiter"
+	"one-api/constant"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -93,25 +94,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 	}
 
 	// 2. Check the total-request limit and record the request (skipped automatically when totalMaxCount is 0; uses a token-bucket limiter)
-	totalKey := fmt.Sprintf("rateLimit:%s", userId)
-	// initialize
-	tb := limiter.New(ctx, rdb)
-	allowed, err = tb.Allow(
-		ctx,
-		totalKey,
-		limiter.WithCapacity(int64(totalMaxCount)*duration),
-		limiter.WithRate(int64(totalMaxCount)),
-		limiter.WithRequested(duration),
-	)
+	if totalMaxCount > 0 {
+		totalKey := fmt.Sprintf("rateLimit:%s", userId)
+		// initialize
+		tb := limiter.New(ctx, rdb)
+		allowed, err = tb.Allow(
+			ctx,
+			totalKey,
+			limiter.WithCapacity(int64(totalMaxCount)*duration),
+			limiter.WithRate(int64(totalMaxCount)),
+			limiter.WithRequested(duration),
+		)
 
-	if err != nil {
-		fmt.Println("检查总请求数限制失败:", err.Error())
-		abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
-		return
-	}
+		if err != nil {
+			fmt.Println("检查总请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
 
-	if !allowed {
-		abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+		}
 	}
 
 	// 4. handle the request
@@ -173,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
 	totalMaxCount := setting.ModelRequestRateLimitCount
 	successMaxCount := setting.ModelRequestRateLimitSuccessCount
 
+	// get the group
+	group := c.GetString("token_group")
+	if group == "" {
+		group = c.GetString(constant.ContextKeyUserGroup)
+	}
+
+	// get the group's rate-limit configuration
+	groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+	if found {
+		totalMaxCount = groupTotalCount
+		successMaxCount = groupSuccessCount
+	}
+
 	// choose and run the rate-limit handler based on the storage type
 	if common.RedisEnabled {
 		redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
diff --git a/model/ability.go b/model/ability.go
index 52720307..38b0bd73 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -50,7 +50,7 @@ func getPriority(group string, model string, retry int) (int, error) {
 	err := DB.Model(&Ability{}).
 		Select("DISTINCT(priority)").
 		Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
-		Order("priority DESC"). // sort by priority in descending order
+		Order("priority DESC"). // sort by priority in descending order
 		Pluck("priority", &priorities).Error // Pluck scans the query result directly into a slice
 
 	if err != nil {
@@ -261,12 +261,28 @@ func FixAbility() (int, error) {
 		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
 		return 0, err
 	}
-	// Delete abilities of channels that are not in channel table
-	err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
-	if err != nil {
-		common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
-		return 0, err
+
+	// Delete abilities of channels that are not in the channel table. Note that a
+	// NOT IN condition cannot simply be chunked (each chunk would match the rows
+	// belonging to every other chunk), so collect the orphaned ids first and
+	// delete them with IN, in batches, to avoid the "too many placeholders" error.
+	if len(channelIds) > 0 {
+		validIds := make(map[int]bool, len(channelIds))
+		for _, id := range channelIds {
+			validIds[id] = true
+		}
+		var abilityIds []int
+		err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityIds).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
+			return 0, err
+		}
+		var orphanIds []int
+		for _, id := range abilityIds {
+			if !validIds[id] {
+				orphanIds = append(orphanIds, id)
+			}
+		}
+		for _, chunk := range lo.Chunk(orphanIds, 100) {
+			err = DB.Where("channel_id IN (?)", chunk).Delete(&Ability{}).Error
+			if err != nil {
+				common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
+				return 0, err
+			}
+		}
+	} else {
+		// If no channels exist, delete all abilities
+		err = DB.Delete(&Ability{}).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
+			return 0, err
+		}
+		common.SysLog("Delete all abilities successfully")
+		return 0, nil
 	}
+
 	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
 	count += len(channelIds)
 
@@ -275,17 +291,26 @@
 	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
 	if err != nil {
 		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
-		return 0, err
+		return count, err
 	}
+
 	var channels []Channel
 	if len(abilityChannelIds) == 0 {
 		err = DB.Find(&channels).Error
 	} else {
-		err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
-	}
-	if err != nil {
-		return 0, err
+		// The same NOT IN caveat applies here: filter in memory rather than
+		// chunking the exclusion list.
+		existing := make(map[int]bool, len(abilityChannelIds))
+		for _, id := range abilityChannelIds {
+			existing[id] = true
+		}
+		var allChannels []Channel
+		err = DB.Find(&allChannels).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
+			return count, err
+		}
+		for _, ch := range allChannels {
+			if !existing[ch.Id] {
+				channels = append(channels, ch)
+			}
+		}
 	}
+
 	for _, channel := range channels {
 		err := channel.UpdateAbilities(nil)
 		if err != nil {
diff --git a/model/cache.go b/model/cache.go
index 2d1c36bf..e2f83e22 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -16,6 +16,9 @@ var channelsIDM map[int]*Channel
 var channelSyncLock sync.RWMutex
 
 func InitChannelCache() {
+	if !common.MemoryCacheEnabled {
+		return
+	}
 	newChannelId2channel := make(map[int]*Channel)
 	var channels []*Channel
 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
@@ -84,11 +87,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
 	if !common.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model, retry)
 	}
-	
+
 	channelSyncLock.RLock()
 	channels := group2model2channels[group][model]
 	channelSyncLock.RUnlock()
-	
+
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
diff --git a/model/channel.go b/model/channel.go
index 41e5e371..ed7a0a7e 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
 	return strings.Split(strings.Trim(channel.Models, ","), ",")
 }
 
+func (channel *Channel) GetGroups() []string {
+	if channel.Group == "" {
+		return []string{}
+	}
+	groups := strings.Split(strings.Trim(channel.Group, ","), ",")
+	for i, group := range groups {
+		groups[i] = strings.TrimSpace(group)
+	}
+	return groups
+}
+
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	otherInfo := make(map[string]interface{})
 	if channel.OtherInfo != "" {
diff --git a/model/option.go b/model/option.go
index d575742f..d892b120 100644
--- a/model/option.go
+++ b/model/option.go
@@ -67,6 +67,7 @@ func InitOptionMap() {
 	common.OptionMap["ServerAddress"] = ""
 	common.OptionMap["WorkerUrl"] = setting.WorkerUrl
 	common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
+	common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
 	common.OptionMap["PayAddress"] = ""
 	common.OptionMap["CustomCallbackAddress"] = ""
 	common.OptionMap["EpayId"] = ""
@@ -92,6 +93,7 @@ func InitOptionMap() {
 	common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
 	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
 	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+	common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
 	common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
 	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -256,6 +258,8 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.StopOnSensitiveEnabled = boolValue
 	case "SMTPSSLEnabled":
 		common.SMTPSSLEnabled = boolValue
+	case "WorkerAllowHttpImageRequestEnabled":
+		setting.WorkerAllowHttpImageRequestEnabled = boolValue
 		}
 	}
 	switch key {
@@ -338,6 +342,8 @@
 		setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
 	case "ModelRequestRateLimitSuccessCount":
 		setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+	case "ModelRequestRateLimitGroup":
+		err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
 	case "RetryTimes":
 		common.RetryTimes, _ = strconv.Atoi(value)
 	case "DataExportInterval":
diff --git a/model/user.go
b/model/user.go index 0aea2ff5..1a3372aa 100644 --- a/model/user.go +++ b/model/user.go @@ -18,6 +18,7 @@ type User struct { Id int `json:"id"` Username string `json:"username" gorm:"unique;index" validate:"max=12"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database! DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` Role int `json:"role" gorm:"type:int;default:1"` // admin, common Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index e097dbe6..50255d0a 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -1,11 +1,12 @@ package channel import ( - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" relaycommon "one-api/relay/common" + + "github.com/gin-gonic/gin" ) type Adaptor interface { @@ -18,6 +19,7 @@ type Adaptor interface { ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) + ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 0cbcef44..31e926d6 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -3,7 +3,6 @@ package ali import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -32,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) case constant.RelayModeImagesGenerations: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) default: fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } @@ -54,6 +57,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + + // fix: ali parameter.enable_thinking must be set to false for non-streaming calls + if !info.IsStream { + request.EnableThinking = false + } + switch info.RelayMode { default: aliReq := requestOpenAI2Ali(*request) @@ -79,6 +88,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info 
*relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 8b2ca889..1d733bd4 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -1,16 +1,23 @@
 package channel
 
 import (
+	"context"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
 	"io"
 	"net/http"
 	common2 "one-api/common"
 	"one-api/relay/common"
 	"one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
+	"one-api/setting/operation_setting"
+	"sync"
+	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -55,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	if err != nil {
 		return nil, fmt.Errorf("get request url failed: %w", err)
 	}
+	if common2.DebugEnabled {
+		println("fullRequestURL:", fullRequestURL)
+	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
 		return nil, fmt.Errorf("new request failed: %w", err)
@@ -94,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return targetConn, nil
 }
 
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+	pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+	gopool.Go(func() {
+		defer func() {
+			if common2.DebugEnabled {
+				println("SSE ping goroutine stopped.")
+			}
+		}()
+
+		if pingInterval <= 0 {
+			pingInterval = helper.DefaultPingInterval
+		}
+
+		ticker := time.NewTicker(pingInterval)
+		// stop the ticker on exit
+		defer ticker.Stop()
+
+		var pingMutex sync.Mutex
+		if common2.DebugEnabled {
+			println("SSE ping goroutine started")
+		}
+
+		for {
+			select {
+			// send ping data
+			case <-ticker.C:
+				if err := sendPingData(c, &pingMutex); err != nil {
+					return
+				}
+			// stop signal received
+			case <-pingerCtx.Done():
+				return
+			// the request has finished
+			case <-c.Request.Context().Done():
+				return
+			}
+		}
+	})
+
+	return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+	mutex.Lock()
+	defer mutex.Unlock()
+
+	err := helper.PingData(c)
+	if err != nil {
+		common2.LogError(c, "SSE ping error: "+err.Error())
+		return err
+	}
+
+	if common2.DebugEnabled {
+		println("SSE ping data sent.")
+	}
+	return nil
+}
+
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error
@@ -105,13 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
+
+	if info.IsStream {
+		helper.SetEventStreamHeaders(c)
+
+		// handle ping keep-alive for streaming requests
+		generalSettings := operation_setting.GetGeneralSetting()
+		if generalSettings.PingIntervalEnabled {
+			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+			stopPinger := startPingKeepAlive(c, pingInterval)
+			defer stopPinger()
+		}
+	}
+
 	resp, err := client.Do(req)
+
 	if err != nil {
 		return nil, err
 	}
 	if resp == nil {
 		return nil, errors.New("resp is nil")
 	}
+
 	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index ceed39a2..9c879399 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -2,13 +2,14 @@ package aws
 
 import (
 	"errors"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
"one-api/relay/channel/claude" relaycommon "one-api/relay/common" "one-api/setting/model_setting" + + "github.com/gin-gonic/gin" ) const ( @@ -74,6 +75,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 37196fd8..64c7b747 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -11,6 +11,8 @@ var awsModelIDMap = map[string]string{ "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0", + "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", } var awsModelCanCrossRegionMap = map[string]map[string]bool{ @@ -41,6 +43,16 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{ }, "anthropic.claude-3-7-sonnet-20250219-v1:0": { "us": true, + "ap": true, + "eu": true, + }, + "anthropic.claude-sonnet-4-20250514-v1:0": { + "us": true, + "ap": true, + "eu": true, + }, + "anthropic.claude-opus-4-20250514-v1:0": { + "us": true, }, } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index eecb0bac..396c31ab 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -3,7 +3,6 @@ package baidu import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -130,6 +131,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return baiduEmbeddingRequest, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index ec7936dc..2b8a52a2 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -3,13 +3,15 @@ package baidu_v2 import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -48,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + if strings.HasSuffix(info.UpstreamModelName, "-search") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search") + request.Model = info.UpstreamModelName + toMap := request.ToMap() + toMap["web_search"] = map[string]any{ + "enable": true, + "enable_citation": true, + "enable_trace": true, + 
"enable_status": false, + } + return toMap, nil + } return request, nil } @@ -60,6 +74,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 6d65d6d4..8389b9f1 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -3,7 +3,6 @@ package claude import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( relaycommon "one-api/relay/common" "one-api/setting/model_setting" "strings" + + "github.com/gin-gonic/gin" ) const ( @@ -37,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { - if strings.HasPrefix(info.UpstreamModelName, "claude-3") { - a.RequestMode = RequestModeMessage - } else { + if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") { a.RequestMode = RequestModeCompletion + } else { + a.RequestMode = RequestModeMessage } } @@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go index d7e0c8e3..e0e3c421 100644 --- a/relay/channel/claude/constants.go +++ b/relay/channel/claude/constants.go @@ -13,6 +13,10 @@ var ModelList = []string{ "claude-3-5-sonnet-20241022", "claude-3-7-sonnet-20250219", "claude-3-7-sonnet-20250219-thinking", + "claude-sonnet-4-20250514", + "claude-sonnet-4-20250514-thinking", + "claude-opus-4-20250514", + "claude-opus-4-20250514-thinking", } var ChannelName = "claude" diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 3d5a5a8a..06f4ca34 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -55,6 +55,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 53a357ad..a93b10f6 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -3,13 +3,14 @@ package cohere import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" 
"one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -52,6 +53,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return requestOpenAI2Cohere(*request), nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go new file mode 100644 index 00000000..80441a51 --- /dev/null +++ b/relay/channel/coze/adaptor.go @@ -0,0 +1,132 @@ +package coze + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/common" + "time" + + "github.com/gin-gonic/gin" +) + +type Adaptor struct { +} + +// ConvertAudioRequest implements channel.Adaptor. +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} + +// ConvertClaudeRequest implements channel.Adaptor. +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertEmbeddingRequest implements channel.Adaptor. +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements channel.Adaptor. +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertOpenAIRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return convertCozeChatRequest(c, *request), nil +} + +// ConvertOpenAIResponsesRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertRerankRequest implements channel.Adaptor. +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// DoRequest implements channel.Adaptor. 
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	if info.IsStream {
+		return channel.DoApiRequest(a, c, info, requestBody)
+	}
+	// non-stream: create the chat first, then fetch the messages once it completes
+	resp, err := channel.DoApiRequest(a, c, info, requestBody)
+	if err != nil {
+		return nil, err
+	}
+	// parse the create-chat response
+	var cozeResponse CozeChatResponse
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	_ = resp.Body.Close()
+	err = json.Unmarshal(respBody, &cozeResponse)
+	if err != nil {
+		return nil, err
+	}
+	if cozeResponse.Code != 0 {
+		return nil, errors.New(cozeResponse.Msg)
+	}
+	c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
+	c.Set("coze_chat_id", cozeResponse.Data.Id)
+	// poll until the chat completes
+	for {
+		err, isComplete := checkIfChatComplete(a, c, info)
+		if err != nil {
+			return nil, err
+		}
+		if isComplete {
+			break
+		}
+		time.Sleep(time.Second * 1)
+	}
+	// fetch the full message list
+	return getChatDetail(a, c, info)
+}
+
+// DoResponse implements channel.Adaptor.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = cozeChatStreamHandler(c, resp, info)
+	} else {
+		err, usage = cozeChatHandler(c, resp, info)
+	}
+	return
+}
+
+// GetChannelName implements channel.Adaptor.
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
+
+// GetModelList implements channel.Adaptor.
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+// GetRequestURL implements channel.Adaptor.
+func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
+}
+
+// Init implements channel.Adaptor.
+func (a *Adaptor) Init(info *common.RelayInfo) {
+
+}
+
+// SetupRequestHeader implements channel.Adaptor.
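+// Coze authenticates with a Bearer token; the channel key is assumed to be a
+// Coze personal-access or OAuth access token.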
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go new file mode 100644 index 00000000..873ffe24 --- /dev/null +++ b/relay/channel/coze/constants.go @@ -0,0 +1,30 @@ +package coze + +var ModelList = []string{ + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "Baichuan4", + "abab6.5s-chat-pro", + "glm-4-0520", + "qwen-max", + "deepseek-r1", + "deepseek-v3", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-7b", + "step-1v-8k", + "step-1.5v-mini", + "Doubao-pro-32k", + "Doubao-pro-256k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-vision-lite-32k", + "Doubao-vision-pro-32k", + "Doubao-1.5-pro-vision-32k", + "Doubao-1.5-lite-32k", + "Doubao-1.5-pro-32k", + "Doubao-1.5-thinking-pro", + "Doubao-1.5-pro-256k", +} + +var ChannelName = "coze" diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go new file mode 100644 index 00000000..4e9afa23 --- /dev/null +++ b/relay/channel/coze/dto.go @@ -0,0 +1,78 @@ +package coze + +import "encoding/json" + +type CozeError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type CozeEnterMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ContentType string `json:"content_type,omitempty"` +} + +type CozeChatRequest struct { + BotId string `json:"bot_id"` + UserId string `json:"user_id"` + AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"` + Stream bool `json:"stream,omitempty"` + CustomVariables json.RawMessage `json:"custom_variables,omitempty"` + AutoSaveHistory bool `json:"auto_save_history,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ExtraParams json.RawMessage `json:"extra_params,omitempty"` + ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type CozeChatResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data CozeChatResponseData `json:"data"` +} + +type CozeChatResponseData struct { + Id string `json:"id"` + ConversationId string `json:"conversation_id"` + BotId string `json:"bot_id"` + CreatedAt int64 `json:"created_at"` + LastError CozeError `json:"last_error"` + Status string `json:"status"` + Usage CozeChatUsage `json:"usage"` +} + +type CozeChatUsage struct { + TokenCount int `json:"token_count"` + OutputCount int `json:"output_count"` + InputCount int `json:"input_count"` +} + +type CozeChatDetailResponse struct { + Data []CozeChatV3MessageDetail `json:"data"` + Code int `json:"code"` + Msg string `json:"msg"` + Detail CozeResponseDetail `json:"detail"` +} + +type CozeChatV3MessageDetail struct { + Id string `json:"id"` + Role string `json:"role"` + Type string `json:"type"` + BotId string `json:"bot_id"` + ChatId string `json:"chat_id"` + Content json.RawMessage `json:"content"` + MetaData json.RawMessage `json:"meta_data"` + CreatedAt int64 `json:"created_at"` + SectionId string `json:"section_id"` + UpdatedAt int64 `json:"updated_at"` + ContentType string `json:"content_type"` + ConversationId string `json:"conversation_id"` + ReasoningContent string `json:"reasoning_content"` +} + +type CozeResponseDetail struct { + Logid string `json:"logid"` 
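+	// Logid is Coze's per-request trace id (assumed useful for upstream debugging)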
+}
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
new file mode 100644
index 00000000..6db40213
--- /dev/null
+++ b/relay/channel/coze/relay-coze.go
@@ -0,0 +1,300 @@
+package coze
+
+import (
+	"bufio"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
+	var messages []CozeEnterMessage
+	// convert the content of user-role messages into CozeEnterMessage entries
+	for _, message := range request.Messages {
+		if message.Role == "user" {
+			messages = append(messages, CozeEnterMessage{
+				Role:    "user",
+				Content: message.Content,
+				// TODO: support more content type
+				ContentType: "text",
+			})
+		}
+	}
+	user := request.User
+	if user == "" {
+		user = helper.GetResponseID(c)
+	}
+	cozeRequest := &CozeChatRequest{
+		BotId:              c.GetString("bot_id"),
+		UserId:             user,
+		AdditionalMessages: messages,
+		Stream:             request.Stream,
+	}
+	return cozeRequest
+}
+
+func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// convert coze response to openai response
+	var response dto.TextResponse
+	var cozeResponse CozeChatDetailResponse
+	response.Model = info.UpstreamModelName
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if cozeResponse.Code != 0 {
+		return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
+	}
+	// read usage from the context (set while polling for completion)
+	var usage dto.Usage
+	usage.PromptTokens = c.GetInt("coze_input_count")
+	usage.CompletionTokens = c.GetInt("coze_output_count")
+	usage.TotalTokens = c.GetInt("coze_token_count")
+	response.Usage = usage
+	response.Id = helper.GetResponseID(c)
+
+	var responseContent json.RawMessage
+	for _, data := range cozeResponse.Data {
+		if data.Type == "answer" {
+			responseContent = data.Content
+			response.Created = data.CreatedAt
+		}
+	}
+	// assemble response.Choices
+	response.Choices = []dto.OpenAITextResponseChoice{
+		{
+			Index:        0,
+			Message:      dto.Message{Role: "assistant", Content: responseContent},
+			FinishReason: "stop",
+		},
+	}
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return nil, &usage
+}
+
+func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+	helper.SetEventStreamHeaders(c)
+	id := helper.GetResponseID(c)
+	var responseText string
+
+	var currentEvent string
+	var currentData string
+	var usage dto.Usage
+
+	for scanner.Scan() {
+		line := scanner.Text()
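+		// Coze SSE frames arrive as "event:"/"data:" line pairs separated by a
+		// blank line; buffer the pair and dispatch once the separator is seen.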
+
+		if line == "" {
+			if currentEvent != "" && currentData != "" {
+				// handle last event
+				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				currentEvent = ""
+				currentData = ""
+			}
+			continue
+		}
+
+		if strings.HasPrefix(line, "event:") {
+			currentEvent = strings.TrimSpace(line[6:])
+			continue
+		}
+
+		if strings.HasPrefix(line, "data:") {
+			currentData = strings.TrimSpace(line[5:])
+			continue
+		}
+	}
+
+	// Last event
+	if currentEvent != "" && currentData != "" {
+		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+	}
+
+	if err := scanner.Err(); err != nil {
+		return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
+	}
+	helper.Done(c)
+
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
+	return nil, &usage
+}
+
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+	switch event {
+	case "conversation.chat.completed":
+		// parse data as CozeChatResponseData
+		var chatData CozeChatResponseData
+		err := json.Unmarshal([]byte(data), &chatData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		usage.PromptTokens = chatData.Usage.InputCount
+		usage.CompletionTokens = chatData.Usage.OutputCount
+		usage.TotalTokens = chatData.Usage.TokenCount
+
+		finishReason := "stop"
+		stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
+		helper.ObjectData(c, stopResponse)
+
+	case "conversation.message.delta":
+		// parse data as CozeChatV3MessageDetail
+		var messageData CozeChatV3MessageDetail
+		err := json.Unmarshal([]byte(data), &messageData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		var content string
+		err = json.Unmarshal(messageData.Content, &content)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		*responseText += content
+
+		openaiResponse := dto.ChatCompletionsStreamResponse{
+			Id:      id,
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   info.UpstreamModelName,
+		}
+
+		choice := dto.ChatCompletionsStreamResponseChoice{
+			Index: 0,
+		}
+		choice.Delta.SetContentString(content)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
+
+		helper.ObjectData(c, openaiResponse)
+
+	case "error":
+		var errorData CozeError
+		err := json.Unmarshal([]byte(data), &errorData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		common.SysError(fmt.Sprintf("stream event error: %d: %s", errorData.Code, errorData.Message))
+	}
+}
+
+func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
+	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	// send a GET request with conversation_id and chat_id as query parameters
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return err, false
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return err, false
+	}
+
+	resp, err := doRequest(req, info) // call the channel-aware doRequest
+	if err != nil {
+		return err,
false + } + if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic + return fmt.Errorf("resp is nil"), false + } + defer resp.Body.Close() // 确保响应体被关闭 + + // 解析 resp 到 CozeChatResponse + var cozeResponse CozeChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body failed: %w", err), false + } + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return fmt.Errorf("unmarshal response body failed: %w", err), false + } + if cozeResponse.Data.Status == "completed" { + // 在上下文设置 usage + c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount) + c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount) + c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount) + return nil, true + } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" { + return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false + } else { + return nil, false + } +} + +func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(req, info) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { + var client *http.Client + var err error // 声明 err 变量 + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + resp, err := client.Do(req) + if err != nil { // 增加对 client.Do(req) 返回错误的检查 + return nil, fmt.Errorf("client.Do failed: %w", err) + } + // _ = resp.Body.Close() + return resp, nil +} diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index f6e910e8..76e7fa8d 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -3,7 +3,6 @@ package deepseek import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -12,6 +11,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -71,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index dddcb994..51dbee71 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -3,12 +3,13 @@ package dify import ( "errors" "fmt" 
- "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + + "github.com/gin-gonic/gin" ) const ( @@ -86,6 +87,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index feaed8f4..e6f66d5f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/service" "one-api/setting/model_setting" "strings" @@ -155,11 +156,24 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return geminiRequest, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if info.RelayMode == constant.RelayModeGemini { + if info.IsStream { + return GeminiTextGenerationStreamHandler(c, resp, info) + } else { + return GeminiTextGenerationHandler(c, resp, info) + } + } + if strings.HasPrefix(info.UpstreamModelName, "imagen") { return GeminiImageHandler(c, resp, info) } diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 5d5c1287..a0e38cb4 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -2,10 +2,10 @@ package gemini type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` - SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"` Tools []GeminiChatTool `json:"tools,omitempty"` - SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"` + SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` } type GeminiThinkingConfig struct { @@ -54,6 +54,7 @@ type GeminiFileData struct { type GeminiPart struct { Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go new file mode 100644 index 00000000..c055e299 --- /dev/null +++ b/relay/channel/gemini/relay-gemini-native.go @@ -0,0 +1,128 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon 
"one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + + "github.com/gin-gonic/gin" +) + +func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + // 读取响应体 + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + + if common.DebugEnabled { + println(string(responseBody)) + } + + // 解析为 Gemini 原生响应格式 + var geminiResponse GeminiChatResponse + err = common.DecodeJson(responseBody, &geminiResponse) + if err != nil { + return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + // 检查是否有候选响应 + if len(geminiResponse.Candidates) == 0 { + return nil, &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + } + } + + // 计算使用量(基于 UsageMetadata) + usage := dto.Usage{ + PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, + CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, + TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, + } + + // 直接返回 Gemini 原生格式的 JSON 响应 + jsonResponse, err := json.Marshal(geminiResponse) + if err != nil { + return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) + } + + // 设置响应头并写入响应 + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError) + } + + return &usage, nil +} + +func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + var usage = &dto.Usage{} + var imageCount int + + helper.SetEventStreamHeaders(c) + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + var geminiResponse GeminiChatResponse + err := common.DecodeJsonStr(data, &geminiResponse) + if err != nil { + common.LogError(c, "error unmarshalling stream response: "+err.Error()) + return false + } + + // 统计图片数量 + for _, candidate := range geminiResponse.Candidates { + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && part.InlineData.MimeType != "" { + imageCount++ + } + } + } + + // 更新使用量统计 + if geminiResponse.UsageMetadata.TotalTokenCount != 0 { + usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount + usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount + } + + // 直接发送 GeminiChatResponse 响应 + err = helper.ObjectData(c, geminiResponse) + if err != nil { + common.LogError(c, err.Error()) + } + + return true + }) + + if imageCount != 0 { + if usage.CompletionTokens == 0 { + usage.CompletionTokens = imageCount * 258 + } + } + + // 计算最终使用量 + usage.PromptTokensDetails.TextTokens = usage.PromptTokens + usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + + // 结束流式响应 + helper.Done(c) + + return usage, nil +} diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index dbe65528..bf1ece57 100644 
--- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -18,6 +18,24 @@ import ( "github.com/gin-gonic/gin" ) +var geminiSupportedMimeTypes = map[string]bool{ + "application/pdf": true, + "audio/mpeg": true, + "audio/mp3": true, + "audio/wav": true, + "image/png": true, + "image/jpeg": true, + "text/plain": true, + "video/mov": true, + "video/mpeg": true, + "video/mp4": true, + "video/mpg": true, + "video/avi": true, + "video/wmv": true, + "video/mpegps": true, + "video/flv": true, +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) { @@ -39,15 +57,22 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if strings.HasSuffix(info.OriginModelName, "-thinking") { - budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) - if budgetTokens == 0 || budgetTokens > 24576 { - budgetTokens = 24576 - } - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - ThinkingBudget: common.GetPointer(int(budgetTokens)), - IncludeThoughts: true, - } + if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 如果模型名以 gemini-2.5-pro 开头,不设置 ThinkingBudget + if strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + IncludeThoughts: true, + } + } else { + budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) + if budgetTokens == 0 || budgetTokens > 24576 { + budgetTokens = 24576 + } + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(int(budgetTokens)), + IncludeThoughts: true, + } + } } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ ThinkingBudget: common.GetPointer(0), @@ -208,14 +233,20 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } // 判断是否是url if strings.HasPrefix(part.GetImageMedia().Url, "http") { - // 是url,获取图片的类型和base64编码的数据 + // 是url,获取文件的类型和base64编码的数据 fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url) if err != nil { - return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) + return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err) } + + // 校验 MimeType 是否在 Gemini 支持的白名单中 + if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok { + return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. 
Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList()) + } + parts = append(parts, GeminiPart{ InlineData: &GeminiInlineData{ - MimeType: fileData.MimeType, + MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义 Data: fileData.Base64Data, }, }) @@ -284,100 +315,126 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon return &geminiRequest, nil } +// Helper function to get a list of supported MIME types for error messages +func getSupportedMimeTypesList() []string { + keys := make([]string, 0, len(geminiSupportedMimeTypes)) + for k := range geminiSupportedMimeTypes { + keys = append(keys, k) + } + return keys +} + // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters. func cleanFunctionParameters(params interface{}) interface{} { if params == nil { return nil } - paramMap, ok := params.(map[string]interface{}) - if !ok { - // Not a map, return as is (e.g., could be an array or primitive) - return params - } + switch v := params.(type) { + case map[string]interface{}: + // Create a copy to avoid modifying the original + cleanedMap := make(map[string]interface{}) + for k, val := range v { + cleanedMap[k] = val + } - // Create a copy to avoid modifying the original - cleanedMap := make(map[string]interface{}) - for k, v := range paramMap { - cleanedMap[k] = v - } + // Remove unsupported root-level fields + delete(cleanedMap, "default") + delete(cleanedMap, "exclusiveMaximum") + delete(cleanedMap, "exclusiveMinimum") + delete(cleanedMap, "$schema") + delete(cleanedMap, "additionalProperties") - // Remove unsupported root-level fields - delete(cleanedMap, "default") - delete(cleanedMap, "exclusiveMaximum") - delete(cleanedMap, "exclusiveMinimum") - delete(cleanedMap, "$schema") - delete(cleanedMap, "additionalProperties") - - // Clean properties - if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil { - cleanedProps := make(map[string]interface{}) - for propName, propValue := range props { - propMap, ok := propValue.(map[string]interface{}) - if !ok { - cleanedProps[propName] = propValue // Keep non-map properties - continue - } - - // Create a copy of the property map - cleanedPropMap := make(map[string]interface{}) - for k, v := range propMap { - cleanedPropMap[k] = v - } - - // Remove unsupported fields - delete(cleanedPropMap, "default") - delete(cleanedPropMap, "exclusiveMaximum") - delete(cleanedPropMap, "exclusiveMinimum") - delete(cleanedPropMap, "$schema") - delete(cleanedPropMap, "additionalProperties") - - // Check and clean 'format' for string types - if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" { - if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists { - if formatValue != "enum" && formatValue != "date-time" { - delete(cleanedPropMap, "format") - } + // Check and clean 'format' for string types + if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" { + if formatValue, formatExists := cleanedMap["format"].(string); formatExists { + if formatValue != "enum" && formatValue != "date-time" { + delete(cleanedMap, "format") } } + } - // Recursively clean nested properties within this property if it's an object/array - // Check the type before recursing - if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") { - cleanedProps[propName] = 
cleanFunctionParameters(cleanedPropMap) - } else { - cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing + // Clean properties + if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil { + cleanedProps := make(map[string]interface{}) + for propName, propValue := range props { + cleanedProps[propName] = cleanFunctionParameters(propValue) } - + cleanedMap["properties"] = cleanedProps } - cleanedMap["properties"] = cleanedProps - } - // Recursively clean items in arrays if needed (e.g., type: array, items: { ... }) - if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil { - cleanedMap["items"] = cleanFunctionParameters(items) - } - // Also handle items if it's an array of schemas - if itemsArray, ok := cleanedMap["items"].([]interface{}); ok { - cleanedItemsArray := make([]interface{}, len(itemsArray)) - for i, item := range itemsArray { - cleanedItemsArray[i] = cleanFunctionParameters(item) + // Recursively clean items in arrays + if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil { + cleanedMap["items"] = cleanFunctionParameters(items) } - cleanedMap["items"] = cleanedItemsArray - } - - // Recursively clean other schema composition keywords if necessary - for _, field := range []string{"allOf", "anyOf", "oneOf"} { - if nested, ok := cleanedMap[field].([]interface{}); ok { - cleanedNested := make([]interface{}, len(nested)) - for i, item := range nested { - cleanedNested[i] = cleanFunctionParameters(item) + // Also handle items if it's an array of schemas + if itemsArray, ok := cleanedMap["items"].([]interface{}); ok { + cleanedItemsArray := make([]interface{}, len(itemsArray)) + for i, item := range itemsArray { + cleanedItemsArray[i] = cleanFunctionParameters(item) } - cleanedMap[field] = cleanedNested + cleanedMap["items"] = cleanedItemsArray } - } - return cleanedMap + // Recursively clean other schema composition keywords + for _, field := range []string{"allOf", "anyOf", "oneOf"} { + if nested, ok := cleanedMap[field].([]interface{}); ok { + cleanedNested := make([]interface{}, len(nested)) + for i, item := range nested { + cleanedNested[i] = cleanFunctionParameters(item) + } + cleanedMap[field] = cleanedNested + } + } + + // Recursively clean patternProperties + if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok { + cleanedPatternProps := make(map[string]interface{}) + for pattern, schema := range patternProps { + cleanedPatternProps[pattern] = cleanFunctionParameters(schema) + } + cleanedMap["patternProperties"] = cleanedPatternProps + } + + // Recursively clean definitions + if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok { + cleanedDefinitions := make(map[string]interface{}) + for defName, defSchema := range definitions { + cleanedDefinitions[defName] = cleanFunctionParameters(defSchema) + } + cleanedMap["definitions"] = cleanedDefinitions + } + + // Recursively clean $defs (newer JSON Schema draft) + if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok { + cleanedDefs := make(map[string]interface{}) + for defName, defSchema := range defs { + cleanedDefs[defName] = cleanFunctionParameters(defSchema) + } + cleanedMap["$defs"] = cleanedDefs + } + + // Clean conditional keywords + for _, field := range []string{"if", "then", "else", "not"} { + if nested, ok := cleanedMap[field]; ok { + cleanedMap[field] = cleanFunctionParameters(nested) + } + } + + return cleanedMap + + case []interface{}: + 
// Handle arrays of schemas + cleanedArray := make([]interface{}, len(v)) + for i, item := range v { + cleanedArray[i] = cleanFunctionParameters(item) + } + return cleanedArray + + default: + // Not a map or array, return as is (e.g., could be a primitive) + return params + } } func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} { @@ -391,6 +448,7 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac } // 删除所有的title字段 delete(v, "title") + delete(v, "$schema") // 如果type不为object和array,则直接返回 if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") { return schema @@ -538,6 +596,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp if call := getResponseToolCall(&part); call != nil { toolCalls = append(toolCalls, *call) } + } else if part.Thought { + choice.Message.ReasoningContent = part.Text } else { if part.ExecutableCode != nil { texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```") @@ -555,7 +615,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp choice.Message.SetToolCalls(toolCalls) isToolCall = true } - choice.Message.SetStringContent(strings.Join(texts, "\n")) } @@ -595,6 +654,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C } var texts []string isTools := false + isThought := false if candidate.FinishReason != nil { // p := GeminiConvertFinishReason(*candidate.FinishReason) switch *candidate.FinishReason { @@ -619,6 +679,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C call.SetIndex(len(choice.Delta.ToolCalls)) choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call) } + } else if part.Thought { + isThought = true + texts = append(texts, part.Text) } else { if part.ExecutableCode != nil { texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n") @@ -631,7 +694,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C } } } - choice.Delta.SetContentString(strings.Join(texts, "\n")) + if isThought { + choice.Delta.SetReasoningContent(strings.Join(texts, "\n")) + } else { + choice.Delta.SetContentString(strings.Join(texts, "\n")) + } if isTools { choice.FinishReason = &constant.FinishReasonToolCalls } @@ -715,8 +782,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + if common.DebugEnabled { + println(string(responseBody)) + } var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) + err = common.DecodeJson(responseBody, &geminiResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 3faac243..85b6a83f 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -3,7 +3,6 @@ package jina import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -12,6 +11,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/common_handler" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -55,6 +56,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return request, nil } +func (a 
*Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 82c82496..44f57e61 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -2,13 +2,14 @@ package mistral import ( "errors" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -59,6 +60,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 304351fd..b889f225 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -3,7 +3,6 @@ package mokaai import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -74,6 +75,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 39e408ab..18069311 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -2,7 +2,6 @@ package ollama import ( "errors" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -10,6 +9,8 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -64,6 +65,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return requestOpenAI2Embeddings(request), nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 502cee69..f0cf073f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "io" "mime/multipart" 
"net/http" + "net/textproto" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -26,7 +27,6 @@ import ( "strings" "github.com/gin-gonic/gin" - "net/textproto" ) type Adaptor struct { @@ -89,7 +89,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") model_ := info.UpstreamModelName - model_ = strings.Replace(model_, ".", "", -1) + // 2025年5月10日后创建的渠道不移除. + if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime { + model_ = strings.Replace(model_, ".", "", -1) + } // https://github.com/songquanpeng/one-api/issues/67 requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) if info.RelayMode == constant.RelayModeRealtime { @@ -169,7 +172,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn info.UpstreamModelName = request.Model // o系列模型developer适配(o1-mini除外) - if !strings.HasPrefix(request.Model, "o1-mini") { + if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") { //修改第一个Message的内容,将system改为developer if len(request.Messages) > 0 && request.Messages[0].Role == "system" { request.Messages[0].Role = "developer" @@ -380,6 +383,21 @@ func detectImageMimeType(filename string) string { } } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // 模型后缀转换 reasoning effort + if strings.HasSuffix(request.Model, "-high") { + request.Reasoning.Effort = "high" + request.Model = strings.TrimSuffix(request.Model, "-high") + } else if strings.HasSuffix(request.Model, "-low") { + request.Reasoning.Effort = "low" + request.Model = strings.TrimSuffix(request.Model, "-low") + } else if strings.HasSuffix(request.Model, "-medium") { + request.Reasoning.Effort = "medium" + request.Model = strings.TrimSuffix(request.Model, "-medium") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation || @@ -406,6 +424,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = OpenaiHandlerWithUsage(c, resp, info) case constant.RelayModeRerank: err, usage = common_handler.RerankHandler(c, info, resp) + case constant.RelayModeResponses: + if info.IsStream { + err, usage = OaiResponsesStreamHandler(c, resp, info) + } else { + err, usage = OaiResponsesHandler(c, resp, info) + } default: if info.IsStream { err, usage = OaiStreamHandler(c, resp, info) diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index e7ba2e7b..a068c544 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } } } + +func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { + if data == "" { + return + } + helper.ResponseChunkData(c, streamResponse, data) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b9ed94e2..2e3d8df1 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -215,10 +215,35 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, 
}, nil } + + forceFormat := false + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { + completionTokens := 0 + for _, choice := range simpleResponse.Choices { + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } + simpleResponse.Usage = dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: completionTokens, + TotalTokens: info.PromptTokens + completionTokens, + } + } switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: - break + if forceFormat { + responseBody, err = json.Marshal(simpleResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + } else { + break + } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := json.Marshal(claudeResp) @@ -244,52 +269,29 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI common.SysError("error copying response body: " + err.Error()) } resp.Body.Close() - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm - } - simpleResponse.Usage = dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, - } - } return nil, &simpleResponse.Usage } func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. + // the status code has been judged before, if there is a body reading failure, + // it should be regarded as a non-recoverable error, so it should not return err for external retry. + // Analogous to nginx's load balancing, it will only retry if it can't be requested or + // if the upstream returns a specific status code, once the upstream has already written the header, + // the subsequent failure of the response body should be regarded as a non-recoverable error, + // and can be terminated directly. 
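+	// The audio bytes are therefore streamed straight through to the client;
+	// only prompt tokens are billed, since TTS output has no token accounting.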
+ defer resp.Body.Close() + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.TotalTokens = info.PromptTokens for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + c.Writer.WriteHeaderNow() + _, err := io.Copy(c.Writer, resp.Body) if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + common.LogError(c, err.Error()) } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - usage := &dto.Usage{} - usage.PromptTokens = info.PromptTokens - usage.TotalTokens = info.PromptTokens return nil, usage } diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go new file mode 100644 index 00000000..1d1e060e --- /dev/null +++ b/relay/channel/openai/relay_responses.go @@ -0,0 +1,119 @@ +package openai + +import ( + "bytes" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" + + "github.com/gin-gonic/gin" +) + +func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // read response body + var responsesResponse dto.OpenAIResponsesResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = common.DecodeJson(responseBody, &responsesResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if responsesResponse.Error != nil { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: responsesResponse.Error.Message, + Type: "openai_error", + Code: responsesResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. 
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	// copy response body
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		common.SysError("error copying response body: " + err.Error())
+	}
+	resp.Body.Close()
+	// compute usage
+	usage := dto.Usage{}
+	usage.PromptTokens = responsesResponse.Usage.InputTokens
+	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+	usage.TotalTokens = responsesResponse.Usage.TotalTokens
+	// tally built-in tool usage; skip tool types the usage map does not track
+	for _, tool := range responsesResponse.Tools {
+		if builtInTool, ok := info.ResponsesUsageInfo.BuiltInTools[tool.Type]; ok {
+			builtInTool.CallCount++
+		}
+	}
+	return nil, &usage
+}
+
+func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	if resp == nil || resp.Body == nil {
+		common.LogError(c, "invalid response or response body")
+		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+	}
+
+	var usage = &dto.Usage{}
+	var responseTextBuilder strings.Builder
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+
+		// check whether this chunk carries the completed status and usage info
+		var streamResponse dto.ResponsesStreamResponse
+		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+			sendResponsesStreamData(c, streamResponse, data)
+			switch streamResponse.Type {
+			case "response.completed":
+				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			case "response.output_text.delta":
+				// accumulate output text
+				responseTextBuilder.WriteString(streamResponse.Delta)
+			case dto.ResponsesOutputTypeItemDone:
+				// handle completed built-in tool call items
+				if streamResponse.Item != nil {
+					switch streamResponse.Item.Type {
+					case dto.BuildInCallWebSearchCall:
+						if builtInTool, ok := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; ok {
+							builtInTool.CallCount++
+						}
+					}
+				}
+			}
+		}
+		return true
+	})
+
+	if usage.CompletionTokens == 0 {
+		// count tokens of the accumulated output text
+		tempStr := responseTextBuilder.String()
+		if len(tempStr) > 0 {
+			// stream ended without usage info: fall back to counting the output text
+			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			usage.CompletionTokens = completionTokens
+		}
+	}
+
+	return nil, usage
+}
diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go
new file mode 100644
index 00000000..607f495b
--- /dev/null
+++ b/relay/channel/openrouter/dto.go
@@ -0,0 +1,9 @@
+package openrouter
+
+type RequestReasoning struct {
+	// One of the following (not both):
+	Effort    string `json:"effort,omitempty"`     // Can be "high", "medium", or "low" (OpenAI-style)
+	MaxTokens int    `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
+	// Optional: Default is false. All models support this.
+ Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index f0220f4f..3a06e7ee 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -3,13 +3,14 @@ package palm import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -60,6 +61,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 5727cac7..ca206503 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -3,13 +3,14 @@ package perplexity import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -63,6 +64,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index cf38c15e..89236ea3 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -3,7 +3,6 @@ package siliconflow import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -58,6 +59,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn return request, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index f2b51ee9..44718a25 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -3,7 +3,6 @@ package tencent import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -13,6 +12,8 @@ import ( "one-api/service" "strconv" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -84,6 +85,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c 
*gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 77f29620..31f84abf 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -13,7 +12,11 @@ import ( "one-api/relay/channel/gemini" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" + "one-api/setting/model_setting" "strings" + + "github.com/gin-gonic/gin" ) const ( @@ -29,6 +32,8 @@ var claudeModelMap = map[string]string{ "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219", + "claude-sonnet-4-20250514": "claude-sonnet-4@20250514", + "claude-opus-4-20250514": "claude-opus-4@20250514", } const anthropicVersion = "vertex-2023-10-16" @@ -77,19 +82,37 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { a.AccountCredentials = *adc suffix := "" if a.RequestMode == RequestModeGemini { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { + // suffix -thinking and -nothinking + if strings.HasSuffix(info.OriginModelName, "-thinking") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") + } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") + } + } + if info.IsStream { suffix = "streamGenerateContent?alt=sse" } else { suffix = "generateContent" } - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", - region, - adc.ProjectID, - region, - info.UpstreamModelName, - suffix, - ), nil + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", + adc.ProjectID, + info.UpstreamModelName, + suffix, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", + region, + adc.ProjectID, + region, + info.UpstreamModelName, + suffix, + ), nil + } } else if a.RequestMode == RequestModeClaude { if info.IsStream { suffix = "streamRawPredict?alt=sse" @@ -164,6 +187,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } @@ -174,7 +202,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeClaude: err, usage = claude.ClaudeStreamHandler(c, resp, info, 
claude.RequestModeMessage) case RequestModeGemini: - err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + if info.RelayMode == constant.RelayModeGemini { + usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info) + } else { + err, usage = gemini.GeminiChatStreamHandler(c, resp, info) + } case RequestModeLlama: err, usage = openai.OaiStreamHandler(c, resp, info) } @@ -183,7 +215,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom case RequestModeClaude: err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) case RequestModeGemini: - err, usage = gemini.GeminiChatHandler(c, resp, info) + if info.RelayMode == constant.RelayModeGemini { + usage, err = gemini.GeminiTextGenerationHandler(c, resp, info) + } else { + err, usage = gemini.GeminiChatHandler(c, resp, info) + } case RequestModeLlama: err, usage = openai.OpenaiHandler(c, resp, info) } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 277285b7..a4a48ee9 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -3,7 +3,6 @@ package volcengine import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -12,6 +11,8 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/constant" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -71,6 +72,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return request, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 669b8c68..b5896415 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -2,14 +2,17 @@ package xai import ( "errors" - "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "strings" + + "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -27,15 +30,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - request.Size = "" - return request, nil + xaiRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + N: request.N, + ResponseFormat: request.ResponseFormat, + } + return xaiRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -78,20 +86,26 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not available") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, 
errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = xAIStreamHandler(c, resp, info) - } else { - err, usage = xAIHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + default: + if info.IsStream { + err, usage = xAIStreamHandler(c, resp, info) + } else { + err, usage = xAIHandler(c, resp, info) + } } - //if _, ok := usage.(*dto.Usage); ok && usage != nil { - // usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens - //} - return } diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go index 7036d5f1..b8098475 100644 --- a/relay/channel/xai/dto.go +++ b/relay/channel/xai/dto.go @@ -12,3 +12,16 @@ type ChatCompletionResponse struct { Usage *dto.Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` } + +// quality, size or style are not supported by xAI API at the moment. +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + // Size string `json:"size,omitempty"` + // Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + // Style string `json:"style,omitempty"` + // User string `json:"user,omitempty"` + // ExtraFields json.RawMessage `json:"extra_fields,omitempty"` +} \ No newline at end of file diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 9521bb47..7591e0e7 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -2,7 +2,6 @@ package xunfei import ( "errors" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -10,6 +9,8 @@ import ( relaycommon "one-api/relay/common" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -61,6 +62,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 04369001..b4d8fb30 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -3,12 +3,13 @@ package zhipu import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -71,6 +72,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request return channel.DoApiRequest(a, c, info, requestBody) } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, 
errors.New("not implemented") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = zhipuStreamHandler(c, resp) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index e13a7ad2..222cdff8 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -3,7 +3,6 @@ package zhipu_4v import ( "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" @@ -11,6 +10,8 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + + "github.com/gin-gonic/gin" ) type Adaptor struct { @@ -70,6 +71,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return request, nil } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + // TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index a07ec316..f4fc3c1e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -36,6 +36,7 @@ type ClaudeConvertInfo struct { const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" ) type RerankerInfo struct { @@ -43,6 +44,16 @@ type RerankerInfo struct { ReturnDocuments bool } +type BuildInToolInfo struct { + ToolName string + CallCount int + SearchContextSize string +} + +type ResponsesUsageInfo struct { + BuiltInTools map[string]*BuildInToolInfo +} + type RelayInfo struct { ChannelType int ChannelId int @@ -87,9 +98,11 @@ type RelayInfo struct { UserQuota int RelayFormat string SendResponseCount int + ChannelCreateTime int64 ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo + *ResponsesUsageInfo } // 定义支持流式选项的通道类型 @@ -103,6 +116,8 @@ var streamSupportedChannels = map[int]bool{ common.ChannelTypeVolcEngine: true, common.ChannelTypeOllama: true, common.ChannelTypeXai: true, + common.ChannelTypeDeepSeek: true, + common.ChannelTypeBaiduV2: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { @@ -134,6 +149,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeResponses + info.ResponsesUsageInfo = &ResponsesUsageInfo{ + BuiltInTools: make(map[string]*BuildInToolInfo), + } + if len(req.Tools) > 0 { + for _, tool := range req.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{ + ToolName: tool.Type, + CallCount: 0, + } + switch tool.Type { + case dto.BuildInToolWebSearchPreview: + if tool.SearchContextSize == "" { + tool.SearchContextSize = "medium" + } + info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize + } + } + } + info.IsStream = req.Stream + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -170,14 +210,15 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { OriginModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"), 
//RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -200,6 +241,10 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } + // responses 模式不支持 StreamOptions + if relayconstant.RelayModeResponses == info.RelayMode { + info.SupportStreamOptions = false + } return info } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index fef38f23..3f1ecd78 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -33,6 +33,7 @@ const ( APITypeOpenRouter APITypeXinference APITypeXai + APITypeCoze APITypeDummy // this one is only for count, do not add any channel after this ) @@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeXinference case common.ChannelTypeXai: apiType = APITypeXai + case common.ChannelTypeCoze: + apiType = APITypeCoze } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index e2d51098..f22a20bd 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -40,7 +40,11 @@ const ( RelayModeRerank + RelayModeResponses + RelayModeRealtime + + RelayModeGemini ) func Path2RelayMode(path string) int { @@ -61,6 +65,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeImagesEdits } else if strings.HasPrefix(path, "/v1/edits") { relayMode = RelayModeEdits + } else if strings.HasPrefix(path, "/v1/responses") { + relayMode = RelayModeResponses } else if strings.HasPrefix(path, "/v1/audio/speech") { relayMode = RelayModeAudioSpeech } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { @@ -71,6 +77,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeRerank } else if strings.HasPrefix(path, "/v1/realtime") { relayMode = RelayModeRealtime + } else if strings.HasPrefix(path, "/v1beta/models") { + relayMode = RelayModeGemini } return relayMode } diff --git a/relay/helper/common.go b/relay/helper/common.go index ebfb6d58..35d983f7 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -12,11 +12,19 @@ import ( ) func SetEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + 
c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { @@ -43,6 +51,14 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { } } +func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) { + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)}) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + func StringData(c *gin.Context, str string) error { //str = strings.TrimPrefix(str, "data: ") //str = strings.TrimSuffix(str, "\r") diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index 948c5226..9bf67c03 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -2,9 +2,11 @@ package helper import ( "encoding/json" + "errors" "fmt" - "github.com/gin-gonic/gin" "one-api/relay/common" + + "github.com/gin-gonic/gin" ) func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { @@ -16,9 +18,36 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { if err != nil { return fmt.Errorf("unmarshal_model_mapping_failed") } - if modelMap[info.OriginModelName] != "" { - info.UpstreamModelName = modelMap[info.OriginModelName] - info.IsModelMapped = true + + // 支持链式模型重定向,最终使用链尾的模型 + currentModel := info.OriginModelName + visitedModels := map[string]bool{ + currentModel: true, + } + for { + if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" { + // 模型重定向循环检测,避免无限循环 + if visitedModels[mappedModel] { + if mappedModel == currentModel { + if currentModel == info.OriginModelName { + info.IsModelMapped = false + return nil + } else { + info.IsModelMapped = true + break + } + } + return errors.New("model_mapping_contains_cycle") + } + visitedModels[mappedModel] = true + currentModel = mappedModel + info.IsModelMapped = true + } else { + break + } + } + if info.IsModelMapped { + info.UpstreamModelName = currentModel } } return nil diff --git a/relay/helper/price.go b/relay/helper/price.go index 899c72b9..89efa1da 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -23,7 +23,7 @@ type PriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index abb98f42..c1bc0d6e 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -3,7 +3,6 @@ package helper import ( "bufio" "context" - "github.com/bytedance/gopkg/util/gopool" "io" "net/http" "one-api/common" @@ -14,6 +13,8 @@ import ( "sync" "time" + 
"github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) @@ -32,7 +33,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer resp.Body.Close() streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second - if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") { + if strings.HasPrefix(info.UpstreamModelName, "o") { // twice timeout for thinking model streamingTimeout *= 2 } @@ -115,7 +116,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon } data = data[5:] data = strings.TrimLeft(data, " ") - data = strings.TrimSuffix(data, "\"") + data = strings.TrimSuffix(data, "\r") if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() writeMutex.Lock() // Lock before writing diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go new file mode 100644 index 00000000..93a2b7aa --- /dev/null +++ b/relay/relay-gemini.go @@ -0,0 +1,157 @@ +package relay + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/relay/channel/gemini" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting" + "strings" + + "github.com/gin-gonic/gin" +) + +func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) { + request := &gemini.GeminiChatRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if len(request.Contents) == 0 { + return nil, errors.New("contents is required") + } + return request, nil +} + +// 流模式 +// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx +func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if c.Query("alt") == "sse" { + relayInfo.IsStream = true + } + + // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") { + // relayInfo.IsStream = true + // } +} + +func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { + var inputTexts []string + for _, content := range textRequest.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + if len(inputTexts) == 0 { + return nil, nil + } + + sensitiveWords, err := service.CheckSensitiveInput(inputTexts) + return sensitiveWords, err +} + +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { + // 计算输入 token 数量 + var inputTexts []string + for _, content := range req.Contents { + for _, part := range content.Parts { + if part.Text != "" { + inputTexts = append(inputTexts, part.Text) + } + } + } + + inputText := strings.Join(inputTexts, "\n") + inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + info.PromptTokens = inputTokens + return inputTokens, err +} + +func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { + req, err := getAndValidateGeminiRequest(c) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) + } + + relayInfo := relaycommon.GenRelayInfo(c) + + // 检查 Gemini 流式模式 + checkGeminiStreamMode(c, relayInfo) + + if setting.ShouldCheckPromptSensitive() { + sensitiveWords, err := checkGeminiInputSensitive(req) + if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", 
strings.Join(sensitiveWords, ", "))) + return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + } + } + + // model mapped 模型映射 + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + } + + if value, exists := c.Get("prompt_tokens"); exists { + promptTokens := value.(int) + relayInfo.SetPromptTokens(promptTokens) + } else { + promptTokens, err := getGeminiInputTokens(req, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) + } + c.Set("prompt_tokens", promptTokens) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + + // pre consume quota + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + + adaptor.Init(relayInfo) + + requestBody, err := json.Marshal(req) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + + resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) + if err != nil { + common.LogError(c, "Do gemini request failed: "+err.Error()) + return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + } + + usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) + if openaiErr != nil { + return openaiErr + } + + postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") + return nil +} diff --git a/relay/relay-image.go b/relay/relay-image.go index 70219cc1..dc63cce8 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -41,91 +41,56 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. 
imageRequest.Quality = "standard" } } + if imageRequest.N == 0 { + imageRequest.N = 1 + } default: err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { return nil, err } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-3" + } + + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + // Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") } if imageRequest.Quality == "" { imageRequest.Quality = "standard" } - // N should between 1 and 10 - //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) - //} + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + } else if imageRequest.Model == "gpt-image-1" { + if imageRequest.Quality == "" { + imageRequest.Quality = "auto" + } + } + + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") + } + + if imageRequest.N == 0 { + imageRequest.N = 1 } } - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - if imageRequest.Prompt == "" { - return nil, errors.New("prompt is required") - } - if strings.Contains(imageRequest.Size, "×") { - return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - // x.ai grok-2-image not support size, quality or style - if imageRequest.Size == "empty" { - imageRequest.Size = "" - } - - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") - } - } else if 
imageRequest.Model == "dall-e-3" { - if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") - } - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - //if imageRequest.N != 1 { - // return nil, errors.New("n must be 1") - //} - } - // N should between 1 and 10 - //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) - //} if setting.ShouldCheckPromptSensitive() { words, err := service.CheckSensitiveInput(imageRequest.Prompt) if err != nil { @@ -229,6 +194,10 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { requestBody = bytes.NewBuffer(jsonData) } + if common.DebugEnabled { + println(fmt.Sprintf("image request body: %s", requestBody)) + } + statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, relayInfo, requestBody) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index a7018456..9d0a2077 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -32,7 +32,23 @@ func RelayMidjourneyImage(c *gin.Context) { }) return } - resp, err := http.Get(midjourneyTask.ImageUrl) + var httpClient *http.Client + if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil { + if proxy, ok := channel.GetSetting()["proxy"]; ok { + if proxyURL, ok := proxy.(string); ok && proxyURL != "" { + if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil { + c.JSON(400, gin.H{ + "error": "proxy_url_invalid", + }) + return + } + } + } + } + if httpClient == nil { + httpClient = service.GetHttpClient() + } + resp, err := httpClient.Get(midjourneyTask.ImageUrl) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": "http_get_image_failed", diff --git a/relay/relay-responses.go b/relay/relay-responses.go new file mode 100644 index 00000000..fd3ddb5a --- /dev/null +++ b/relay/relay-responses.go @@ -0,0 +1,171 @@ +package relay + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/setting" + "one-api/setting/model_setting" + "strings" + + "github.com/gin-gonic/gin" +) + +func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { + request := &dto.OpenAIResponsesRequest{} + err := common.UnmarshalBodyReusable(c, request) + if err != nil { + return nil, err + } + if request.Model == "" { + return nil, errors.New("model is required") + } + if len(request.Input) == 0 { + return nil, errors.New("input is required") + } + return request, nil + +} + +func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { + sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) + return sensitiveWords, err +} + +func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) { + inputTokens, err := service.CountTokenInput(req.Input, req.Model) + info.PromptTokens = inputTokens + return inputTokens, err +} + +func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { + req, err := getAndValidateResponsesRequest(c) + if err != nil { + common.LogError(c, 
fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) + return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) + } + + relayInfo := relaycommon.GenRelayInfoResponses(c, req) + + if setting.ShouldCheckPromptSensitive() { + sensitiveWords, err := checkInputSensitive(req, relayInfo) + if err != nil { + common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", "))) + return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest) + } + } + + err = helper.ModelMappedHelper(c, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) + } + req.Model = relayInfo.UpstreamModelName + if value, exists := c.Get("prompt_tokens"); exists { + promptTokens := value.(int) + relayInfo.SetPromptTokens(promptTokens) + } else { + promptTokens, err := getInputTokens(req, relayInfo) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) + } + c.Set("prompt_tokens", promptTokens) + } + + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens)) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) + } + // pre consume quota + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + var requestBody io.Reader + if model_setting.GetGlobalSettings().PassThroughRequestEnabled { + body, err := common.GetRequestBody(c) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(body) + } else { + convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError) + } + // apply param override + if len(relayInfo.ParamOverride) > 0 { + reqMap := make(map[string]interface{}) + err = json.Unmarshal(jsonData, &reqMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError) + } + for key, value := range relayInfo.ParamOverride { + reqMap[key] = value + } + jsonData, err = json.Marshal(reqMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError) + } + } + + if common.DebugEnabled { + println("requestBody: ", string(jsonData)) + } + requestBody = bytes.NewBuffer(jsonData) + } + + var httpResp *http.Response + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + + if resp != nil { + httpResp = 
resp.(*http.Response)
+
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp, false)
+			// reset status code
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+	if openaiErr != nil {
+		// reset status code
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	} else {
+		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	}
+	return nil
+}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 4fdd435d..f1105907 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -18,6 +18,7 @@ import (
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting/model_setting"
+	"one-api/setting/operation_setting"
 	"strings"
 	"time"
 
@@ -46,6 +47,20 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 	if textRequest.Model == "" {
 		return nil, errors.New("model is required")
 	}
+	if textRequest.WebSearchOptions != nil {
+		if textRequest.WebSearchOptions.SearchContextSize != "" {
+			validSizes := map[string]bool{
+				"high":   true,
+				"medium": true,
+				"low":    true,
+			}
+			if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+				return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+			}
+		} else {
+			textRequest.WebSearchOptions.SearchContextSize = "medium"
+		}
+	}
 	switch relayInfo.RelayMode {
 	case relayconstant.RelayModeCompletions:
 		if textRequest.Prompt == "" {
@@ -75,6 +90,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	// get & validate textRequest 获取并验证文本请求
 	textRequest, err := getAndValidateTextRequest(c, relayInfo)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}
+	// textRequest is nil when validation fails, so it must only be dereferenced after the error check
+	if textRequest.WebSearchOptions != nil {
+		c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+	}
@@ -193,6 +212,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	var httpResp *http.Response
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
@@ -358,6 +378,45 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 
 	ratio := dModelRatio.Mul(dGroupRatio)
 
+	// billing for the openai web search tool
+	var dWebSearchQuota decimal.Decimal
+	var webSearchPrice float64
+	if relayInfo.ResponsesUsageInfo != nil {
+		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
+			// web search quota = price * call count / 1000 * group ratio
+			webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
+			dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
+				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
+		}
+	} else if strings.HasSuffix(modelName, "search-preview") {
+		// search-preview models do not support the responses api
+		searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
+		if searchContextSize == "" {
+			searchContextSize = "medium"
+		}
+		webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
+		dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+		extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+			searchContextSize, dWebSearchQuota.String())
+	}
+	// billing for the file search tool
+	var dFileSearchQuota decimal.Decimal
+	var fileSearchPrice float64
+	if relayInfo.ResponsesUsageInfo != nil {
+		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
+			fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
+			dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
+				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
+				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+			extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
+				fileSearchTool.CallCount, dFileSearchQuota.String())
+		}
+	}
+
 	var quotaCalculateDecimal decimal.Decimal
 
 	if !priceData.UsePrice {
 		nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
@@ -380,6 +439,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	} else {
 		quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
 	}
+	// add the quota consumed by responses built-in tool calls
+	quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
+	quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
 
 	quota := int(quotaCalculateDecimal.Round(0).IntPart())
 	totalTokens := promptTokens + completionTokens
@@ -430,6 +492,26 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		other["image_ratio"] = imageRatio
 		other["image_output"] = imageTokens
 	}
+	if !dWebSearchQuota.IsZero() {
+		if relayInfo.ResponsesUsageInfo != nil {
+			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+				other["web_search"] = true
+				other["web_search_call_count"] = webSearchTool.CallCount
+				other["web_search_price"] = webSearchPrice
+			}
+		} else if strings.HasSuffix(modelName, "search-preview") {
+			other["web_search"] = true
+			other["web_search_call_count"] = 1
+			other["web_search_price"] = webSearchPrice
+		}
+	}
+	if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
+		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
+			other["file_search"] = true
+			other["file_search_call_count"] = fileSearchTool.CallCount
+			other["file_search_price"] = fileSearchPrice
+		}
+	}
 
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
 }
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 8b4afcb3..7bf0da9f 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
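+	// coze: new channel adaptor added by this PR and wired up in GetAdaptor below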
"one-api/relay/channel/cohere" + "one-api/relay/channel/coze" "one-api/relay/channel/deepseek" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &openai.Adaptor{} case constant.APITypeXai: return &xai.Adaptor{} + case constant.APITypeCoze: + return &coze.Adaptor{} } return nil } diff --git a/router/relay-router.go b/router/relay-router.go index 85000beb..1115a491 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,10 +1,11 @@ package router import ( - "github.com/gin-gonic/gin" "one-api/controller" "one-api/middleware" "one-api/relay" + + "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -47,6 +48,7 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.POST("/audio/transcriptions", controller.Relay) httpRouter.POST("/audio/translations", controller.Relay) httpRouter.POST("/audio/speech", controller.Relay) + httpRouter.POST("/responses", controller.Relay) httpRouter.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) @@ -77,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) { relaySunoRouter.GET("/fetch/:id", controller.RelayTask) } + relayGeminiRouter := router.Group("/v1beta") + relayGeminiRouter.Use(middleware.TokenAuth()) + relayGeminiRouter.Use(middleware.ModelRequestRateLimit()) + relayGeminiRouter.Use(middleware.Distribute()) + { + // Gemini API 路径格式: /v1beta/models/{model_name}:{action} + relayGeminiRouter.POST("/models/*path", controller.Relay) + } } func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { diff --git a/service/cf_worker.go b/service/cf_worker.go index 40a1e294..ae6e1ffe 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -24,7 +24,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { if !setting.EnableWorker() { return nil, fmt.Errorf("worker not enabled") } - if !strings.HasPrefix(req.URL, "https") { + if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { return nil, fmt.Errorf("only support https url") } diff --git a/service/convert.go b/service/convert.go index cc462b40..cb964a46 100644 --- a/service/convert.go +++ b/service/convert.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/dto" + "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "strings" ) @@ -18,10 +19,24 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re Stream: claudeRequest.Stream, } + isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter + if claudeRequest.Thinking != nil { - if strings.HasSuffix(info.OriginModelName, "-thinking") && - !strings.HasSuffix(claudeRequest.Model, "-thinking") { - openAIRequest.Model = openAIRequest.Model + "-thinking" + if isOpenRouter { + reasoning := openrouter.RequestReasoning{ + MaxTokens: claudeRequest.Thinking.BudgetTokens, + } + reasoningJSON, err := json.Marshal(reasoning) + if err != nil { + return nil, fmt.Errorf("failed to marshal reasoning: %w", err) + } + openAIRequest.Reasoning = reasoningJSON + } else { + thinkingSuffix := "-thinking" + if strings.HasSuffix(info.OriginModelName, thinkingSuffix) && + !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) { + openAIRequest.Model = openAIRequest.Model + thinkingSuffix + } } } @@ -62,16 +77,30 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re } else { systems := claudeRequest.ParseSystem() if 
len(systems) > 0 { - systemStr := "" openAIMessage := dto.Message{ Role: "system", } - for _, system := range systems { - if system.Text != nil { - systemStr += *system.Text + isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude") + if isOpenRouterClaude { + systemMediaMessages := make([]dto.MediaContent, 0, len(systems)) + for _, system := range systems { + message := dto.MediaContent{ + Type: "text", + Text: system.GetText(), + CacheControl: system.CacheControl, + } + systemMediaMessages = append(systemMediaMessages, message) } + openAIMessage.SetMediaContent(systemMediaMessages) + } else { + systemStr := "" + for _, system := range systems { + if system.Text != nil { + systemStr += *system.Text + } + } + openAIMessage.SetStringContent(systemStr) } - openAIMessage.SetStringContent(systemStr) openAIMessages = append(openAIMessages, openAIMessage) } } @@ -97,8 +126,9 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re switch mediaMsg.Type { case "text": message := dto.MediaContent{ - Type: "text", - Text: mediaMsg.GetText(), + Type: "text", + Text: mediaMsg.GetText(), + CacheControl: mediaMsg.CacheControl, } mediaMessages = append(mediaMessages, message) case "image": diff --git a/service/http_client.go b/service/http_client.go index c3f8df7a..64a361cf 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -3,12 +3,13 @@ package service import ( "context" "fmt" - "golang.org/x/net/proxy" "net" "net/http" "net/url" "one-api/common" "time" + + "golang.org/x/net/proxy" ) var httpClient *http.Client @@ -55,7 +56,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { }, }, nil - case "socks5": + case "socks5", "socks5h": // 获取认证信息 var auth *proxy.Auth if parsedURL.User != nil { @@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { } // 创建 SOCKS5 代理拨号器 + // proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同 dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) if err != nil { return nil, err diff --git a/service/token_counter.go b/service/token_counter.go index f3c3b6b0..d63b54ad 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -120,11 +120,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m var config image.Config var err error var format string + var b64str string if strings.HasPrefix(imageUrl.Url, "http") { config, format, err = DecodeUrlImageData(imageUrl.Url) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, _, err = DecodeBase64ImageData(imageUrl.Url) + config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) } if err != nil { return 0, err @@ -132,7 +133,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { - return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) + // not an image + if format != "" && b64str != "" { + // file type + return 3 * baseTokens, nil + } + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) } shortSide := config.Width @@ -400,6 +406,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod tokenNum += 100 } else if m.Type == dto.ContentTypeFile { tokenNum += 5000 + } else if m.Type == dto.ContentTypeVideoUrl { + tokenNum += 5000 } else { tokenNum += getTokenNum(tokenEncoder, m.Text) } diff --git 
a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index dd29eac2..ec0c766d 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -36,6 +36,10 @@ var defaultCacheRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 0.1, "claude-3-7-sonnet-20250219": 0.1, "claude-3-7-sonnet-20250219-thinking": 0.1, + "claude-sonnet-4-20250514": 0.1, + "claude-sonnet-4-20250514-thinking": 0.1, + "claude-opus-4-20250514": 0.1, + "claude-opus-4-20250514-thinking": 0.1, } var defaultCreateCacheRatio = map[string]float64{ @@ -47,6 +51,10 @@ var defaultCreateCacheRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 1.25, "claude-3-7-sonnet-20250219": 1.25, "claude-3-7-sonnet-20250219-thinking": 1.25, + "claude-sonnet-4-20250514": 1.25, + "claude-sonnet-4-20250514-thinking": 1.25, + "claude-opus-4-20250514": 1.25, + "claude-opus-4-20250514-thinking": 1.25, } //var defaultCreateCacheRatio = map[string]float64{} diff --git a/setting/operation_setting/model-ratio.go b/setting/operation_setting/model-ratio.go index fdc1c950..700a7c4e 100644 --- a/setting/operation_setting/model-ratio.go +++ b/setting/operation_setting/model-ratio.go @@ -114,7 +114,9 @@ var defaultModelRatio = map[string]float64{ "claude-3-5-sonnet-20241022": 1.5, "claude-3-7-sonnet-20250219": 1.5, "claude-3-7-sonnet-20250219-thinking": 1.5, + "claude-sonnet-4-20250514": 1.5, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens + "claude-opus-4-20250514": 7.5, "ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB, @@ -440,13 +442,15 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { if name == "chatgpt-4o-latest" { return 3, true } - if strings.Contains(name, "claude-instant-1") { - return 3, true - } else if strings.Contains(name, "claude-2") { - return 3, true - } else if strings.Contains(name, "claude-3") { + + if strings.Contains(name, "claude-3") { return 5, true + } else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") { + return 5, true + } else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") { + return 3, true } + if strings.HasPrefix(name, "gpt-3.5") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go new file mode 100644 index 00000000..974c4ed2 --- /dev/null +++ b/setting/operation_setting/tools.go @@ -0,0 +1,57 @@ +package operation_setting + +import "strings" + +const ( + // Web search + WebSearchHighTierModelPriceLow = 30.00 + WebSearchHighTierModelPriceMedium = 35.00 + WebSearchHighTierModelPriceHigh = 50.00 + WebSearchPriceLow = 25.00 + WebSearchPriceMedium = 27.50 + WebSearchPriceHigh = 30.00 + // File search + FileSearchPrice = 2.5 +) + +func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 { + // 确定模型类型 + // https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费 + // gpt-4.1, gpt-4o, or gpt-4o-search-preview 更贵,gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview 更便宜 + isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) && + !strings.Contains(modelName, "mini") + // 确定 search context size 对应的价格 + var priceWebSearchPerThousandCalls float64 + switch contextSize { + case "low": + if isHighTierModel { + priceWebSearchPerThousandCalls = 
diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go new file mode 100644 index 00000000..974c4ed2 --- /dev/null +++ b/setting/operation_setting/tools.go @@ -0,0 +1,57 @@ +package operation_setting + +import "strings" + +const ( + // Web search + WebSearchHighTierModelPriceLow = 30.00 + WebSearchHighTierModelPriceMedium = 35.00 + WebSearchHighTierModelPriceHigh = 50.00 + WebSearchPriceLow = 25.00 + WebSearchPriceMedium = 27.50 + WebSearchPriceHigh = 30.00 + // File search + FileSearchPrice = 2.5 +) + +func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 { + // determine the model tier + // https://platform.openai.com/docs/pricing: web search is billed by model tier and search context size + // gpt-4.1, gpt-4o, and gpt-4o-search-preview cost more; gpt-4.1-mini, gpt-4o-mini, and gpt-4o-mini-search-preview cost less + isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) && + !strings.Contains(modelName, "mini") + // pick the price matching the search context size + var priceWebSearchPerThousandCalls float64 + switch contextSize { + case "low": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceLow + } else { + priceWebSearchPerThousandCalls = WebSearchPriceLow + } + case "medium": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium + } else { + priceWebSearchPerThousandCalls = WebSearchPriceMedium + } + case "high": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceHigh + } else { + priceWebSearchPerThousandCalls = WebSearchPriceHigh + } + default: + // search context size defaults to medium + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium + } else { + priceWebSearchPerThousandCalls = WebSearchPriceMedium + } + } + return priceWebSearchPerThousandCalls +} + +func GetFileSearchPricePerThousand() float64 { + return FileSearchPrice +} diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..53b53f88 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,64 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup = map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex + +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) + if err != nil { + common.SysError("error marshalling model request rate limit group: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + // this function writes the map, so it must take the exclusive lock, not RLock + ModelRequestRateLimitMutex.Lock() + defer ModelRequestRateLimitMutex.Unlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) +} + +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + if ModelRequestRateLimitGroup == nil { + return 0, 0, false + } + + limits, found := ModelRequestRateLimitGroup[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} + +func CheckModelRequestRateLimitGroup(jsonStr string) error { + checkModelRequestRateLimitGroup := make(map[string][2]int) + err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup) + if err != nil { + return err + } + for group, limits := range checkModelRequestRateLimitGroup { + if limits[0] < 0 || limits[1] < 1 { + return fmt.Errorf("group %s has invalid rate limit values: [%d, %d]", group, limits[0], limits[1]) + } + } + + return nil +} diff --git a/setting/system_setting.go b/setting/system_setting.go index 15017d3d..c37a6123 100644 --- a/setting/system_setting.go +++ b/setting/system_setting.go @@ -3,6 +3,7 @@ package setting var ServerAddress = "http://localhost:3000" var WorkerUrl = "" var WorkerValidKey = "" +var WorkerAllowHttpImageRequestEnabled = false func EnableWorker() bool { return WorkerUrl != "" diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 3425beea..f490e14a 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -871,7 +871,16 @@ const ChannelsTable = () => { }; const refresh = async () => { - await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + if (searchKeyword === '' && searchGroup === '' && searchModel === '') 
{ + await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + } else { + await searchChannels( + searchKeyword, + searchGroup, + searchModel, + enableTagMode, + ); + } }; useEffect(() => { @@ -879,9 +888,13 @@ const ChannelsTable = () => { const localIdSort = localStorage.getItem('id-sort') === 'true'; const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true'; + const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true'; setIdSort(localIdSort); setPageSize(localPageSize); - loadChannels(0, localPageSize, localIdSort, enableTagMode) + setEnableTagMode(localEnableTagMode); + setEnableBatchDelete(localEnableBatchDelete); + loadChannels(0, localPageSize, localIdSort, localEnableTagMode) .then() .catch((reason) => { showError(reason); @@ -979,8 +992,8 @@ const ChannelsTable = () => { enableTagMode, ) => { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - await loadChannels(0, pageSize, idSort, enableTagMode); - setActivePage(1); + await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + // setActivePage(1); return; } setSearching(true); @@ -1477,10 +1490,12 @@ const ChannelsTable = () => { {t('开启批量操作')} { + localStorage.setItem('enable-batch-delete', v + ''); setEnableBatchDelete(v); }} /> @@ -1544,6 +1559,7 @@ const ChannelsTable = () => { uncheckedText={t('关')} aria-label={t('是否启用标签聚合')} onChange={(v) => { + localStorage.setItem('enable-tag-mode', v + ''); setEnableTagMode(v); loadChannels(0, pageSize, idSort, v); }} diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 903677eb..6cf7e844 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -618,7 +618,6 @@ const LogsTable = () => { ); } - let content = other?.claude ? renderClaudeModelPriceSimple( other.model_ratio, @@ -935,6 +934,13 @@ const LogsTable = () => { other.model_price, other.group_ratio, other?.user_group_ratio, + false, + 1.0, + undefined, + other.web_search || false, + other.web_search_call_count || 0, + other.file_search || false, + other.file_search_call_count || 0, ), }); } @@ -995,6 +1001,12 @@ const LogsTable = () => { other?.image || false, other?.image_ratio || 0, other?.image_output || 0, + other?.web_search || false, + other?.web_search_call_count || 0, + other?.web_search_price || 0, + other?.file_search || false, + other?.file_search_call_count || 0, + other?.file_search_price || 0, ); } expandDataLocal.push({ diff --git a/web/src/components/ModelSetting.js b/web/src/components/ModelSetting.js index 2a566d6b..9c60a390 100644 --- a/web/src/components/ModelSetting.js +++ b/web/src/components/ModelSetting.js @@ -39,7 +39,9 @@ const ModelSetting = () => { item.key === 'claude.default_max_tokens' || item.key === 'gemini.supported_imagine_models' ) { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); + if (item.value !== '') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } } if (item.key.endsWith('Enabled') || item.key.endsWith('enabled')) { newInputs[item.key] = item.value === 'true' ? 
true : false; @@ -60,6 +62,7 @@ const ModelSetting = () => { // showSuccess('刷新成功'); } catch (error) { showError('刷新失败'); + console.error(error); } finally { setLoading(false); } diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index d1e03db2..0f52c319 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -57,6 +57,7 @@ const PersonalSetting = () => { email_verification_code: '', email: '', self_account_deletion_confirmation: '', + original_password: '', set_new_password: '', set_new_password_confirmation: '', }); @@ -239,11 +240,24 @@ const PersonalSetting = () => { }; const changePassword = async () => { + if (inputs.original_password === '') { + showError(t('请输入原密码!')); + return; + } + if (inputs.set_new_password === '') { + showError(t('请输入新密码!')); + return; + } + if (inputs.original_password === inputs.set_new_password) { + showError(t('新密码需要和原密码不一致!')); + return; + } if (inputs.set_new_password !== inputs.set_new_password_confirmation) { showError(t('两次输入的密码不一致!')); return; } const res = await API.put(`/api/user/self`, { + original_password: inputs.original_password, password: inputs.set_new_password, }); const { success, message } = res.data; @@ -816,8 +830,8 @@ const PersonalSetting = () => {
{t('通知方式')}
@@ -993,23 +1007,36 @@ const PersonalSetting = () => {
- {t('接受未设置价格模型')} + + {t('接受未设置价格模型')} +
handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)} + checked={ + notificationSettings.acceptUnsetModelRatioModel + } + onChange={(e) => + handleNotificationSettingChange( + 'acceptUnsetModelRatioModel', + e.target.checked, + ) + } > {t('接受未设置价格模型')} - - {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} + + {t( + '当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用', + )}
diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -799,7 +812,13 @@ const SystemSetting = () => { onChange={(value) => setEmailToAdd(value)} style={{ marginTop: 16 }} suffix={ - + } onEnterPress={handleAddEmail} /> diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index fa59bcce..054da535 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [ { value: 48, color: 'blue', - label: 'xAI' - } + label: 'xAI', + }, + { + value: 49, + color: 'blue', + label: 'Coze', + }, ]; diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 7b80da6f..5a59356b 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -317,6 +317,12 @@ export function renderModelPrice( image = false, imageRatio = 1.0, imageOutputTokens = 0, + webSearch = false, + webSearchCallCount = 0, + webSearchPrice = 0, + fileSearch = false, + fileSearchCallCount = 0, + fileSearchPrice = 0, ) { if (modelPrice !== -1) { return i18next.t( @@ -339,14 +345,17 @@ // Calculate effective input tokens (non-cached + cached with ratio applied) let effectiveInputTokens = inputTokens - cacheTokens + cacheTokens * cacheRatio; -// Handle image tokens if present + // Handle image tokens if present if (image && imageOutputTokens > 0) { - effectiveInputTokens = inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; + effectiveInputTokens = + inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; } let price = (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + - (completionTokens / 1000000) * completionRatioPrice * groupRatio; + (completionTokens / 1000000) * completionRatioPrice * groupRatio + + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; return ( <> @@ -391,9 +400,23 @@ export function renderModelPrice( )}
)} + {webSearch && webSearchCallCount > 0 && ( + {i18next.t('Web搜索价格:${{price}} / 1K 次', { + price: webSearchPrice, + })} + )} + {fileSearch && fileSearchCallCount > 0 && ( + {i18next.t('文件搜索价格:${{price}} / 1K 次', { + price: fileSearchPrice, + })} + )}
- {cacheTokens > 0 && !image + {cacheTokens > 0 && !image && !webSearch && !fileSearch ? i18next.t( '输入 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { @@ -407,31 +430,82 @@ total: price.toFixed(6), }, ) - : image && imageOutputTokens > 0 - ? i18next.t( - '输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', - { - nonImageInput: inputTokens - imageOutputTokens, - imageInput: imageOutputTokens, - imageRatio: imageRatio, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - total: price.toFixed(6), - }, - ) - : i18next.t( - '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', - { - input: inputTokens, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - total: price.toFixed(6), - }, - )} + : image && imageOutputTokens > 0 && !webSearch && !fileSearch + ? i18next.t( + '输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', + { + nonImageInput: inputTokens - imageOutputTokens, + imageInput: imageOutputTokens, + imageRatio: imageRatio, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6), + }, + ) + : webSearch && webSearchCallCount > 0 && !image && !fileSearch + ? i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + webSearchCallCount, + webSearchPrice, + total: price.toFixed(6), + }, + ) + : fileSearch && + fileSearchCallCount > 0 && + !image && + !webSearch + ? i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + fileSearchCallCount, + fileSearchPrice, + total: price.toFixed(6), + }, + ) + : webSearch && + webSearchCallCount > 0 && + fileSearch && + fileSearchCallCount > 0 && + !image + ? i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + webSearchCallCount, + webSearchPrice, + fileSearchCallCount, + fileSearchPrice, + total: price.toFixed(6), + }, + ) + : i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6), + }, + )}

{i18next.t('仅供参考,以实际扣费为准')}

@@ -448,33 +522,56 @@ export function renderLogContent( user_group_ratio, image = false, imageRatio = 1.0, - useUserGroupRatio = undefined + useUserGroupRatio = undefined, + webSearch = false, + webSearchCallCount = 0, + fileSearch = false, + fileSearchCallCount = 0, ) { - const ratioLabel = useUserGroupRatio ? i18next.t('专属倍率') : i18next.t('分组倍率'); + const ratioLabel = useUserGroupRatio + ? i18next.t('专属倍率') + : i18next.t('分组倍率'); const ratio = useUserGroupRatio ? user_group_ratio : groupRatio; if (modelPrice !== -1) { return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { price: modelPrice, ratioType: ratioLabel, - ratio + ratio, }); } else { if (image) { - return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}', { - modelRatio: modelRatio, - completionRatio: completionRatio, - imageRatio: imageRatio, - ratioType: ratioLabel, - ratio - }); + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + imageRatio: imageRatio, + ratioType: ratioLabel, + ratio, + }, + ); + } else if (webSearch) { + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}},Web 搜索调用 {{webSearchCallCount}} 次', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio, + webSearchCallCount, + }, + ); } else { - return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}', { - modelRatio: modelRatio, - completionRatio: completionRatio, - ratioType: ratioLabel, - ratio - }); + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio, + }, + ); } } } diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e9975f61..916329e7 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -493,6 +493,7 @@ "默认": "default", "图片演示": "Image demo", "注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41", + "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "Channels added after May 10, 2025 no longer need the \".\" removed from the model name when deploying", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", "取消无限额度": "Cancel unlimited quota", "取消": "Cancel", @@ -1085,7 +1086,7 @@ "没有账户?": "No account? 
", "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "默认 API 版本": "Default API Version", - "请输入默认 API 版本,例如:2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", + "请输入默认 API 版本,例如:2025-04-01-preview": "Please enter default API version, e.g.: 2025-04-01-preview.", "请为渠道命名": "Please name the channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", @@ -1373,4 +1374,4 @@ "适用于展示系统功能的场景。": "Suitable for scenarios where the system functions are displayed.", "可在初始化后修改": "Can be modified after initialization", "初始化系统": "Initialize system" -} +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index a793e149..f7fab057 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -24,7 +24,8 @@ import { TextArea, Checkbox, Banner, - Modal, ImagePreview + Modal, + ImagePreview, } from '@douyinfe/semi-ui'; import { getChannelModels, loadChannelModels } from '../../components/utils.js'; import { IconHelpCircle } from '@douyinfe/semi-icons'; @@ -306,7 +307,7 @@ const EditChannel = (props) => { fetchModels().then(); fetchGroups().then(); if (isEdit) { - loadChannel().then(() => { }); + loadChannel().then(() => {}); } else { setInputs(originInputs); let localModels = getChannelModels(inputs.type); @@ -477,24 +478,26 @@ const EditChannel = (props) => { type={'warning'} description={ <> - {t('注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41')} -
- { - setModalImageUrl( - '/azure_model_name.png', - ); - setIsModalOpenurl(true) + {t( + '2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."', + )} + {/*
*/} + {/* {*/} + {/* setModalImageUrl(*/} + {/* '/azure_model_name.png',*/} + {/* );*/} + {/* setIsModalOpenurl(true)*/} - }} - > - {t('查看示例')} -
+ {/* }}*/} + {/*>*/} + {/* {t('查看示例')}*/} + {/**/} } > @@ -522,7 +525,7 @@ const EditChannel = (props) => { { handleInputChange('other', value); }} @@ -584,25 +587,35 @@ const EditChannel = (props) => { value={inputs.name} autoComplete='new-password' /> - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( - <> -
- {t('API地址')}: -
- - { - handleInputChange('base_url', value); - }} - value={inputs.base_url} - autoComplete="new-password" - /> - - - )} + {inputs.type !== 3 && + inputs.type !== 8 && + inputs.type !== 22 && + inputs.type !== 36 && + inputs.type !== 45 && ( + <> +
+ {t('API地址')}: +
+ + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + + )}
{t('密钥')}:
@@ -761,10 +774,10 @@ const EditChannel = (props) => { name='other' placeholder={t( '请输入部署地区,例如:us-central1\n支持使用模型映射格式\n' + - '{\n' + - ' "default": "us-central1",\n' + - ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + - '}', + '{\n' + + ' "default": "us-central1",\n' + + ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + + '}', )} autosize={{ minRows: 2 }} onChange={(value) => { @@ -825,6 +838,22 @@ const EditChannel = (props) => { /> )} + {inputs.type === 49 && ( + <> +
+ 智能体ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
{t('模型')}:
diff --git a/web/src/pages/Home/index.js b/web/src/pages/Home/index.js index 599c7930..84fabf6f 100644 --- a/web/src/pages/Home/index.js +++ b/web/src/pages/Home/index.js @@ -158,7 +158,7 @@ const Home = () => {

{t('OIDC 身份验证')}: - {statusState?.status?.oidc === true + {statusState?.status?.oidc_enabled === true ? t('已启用') : t('未启用')}

diff --git a/web/src/pages/Playground/Playground.js b/web/src/pages/Playground/Playground.js index e8138c01..08eada17 100644 --- a/web/src/pages/Playground/Playground.js +++ b/web/src/pages/Playground/Playground.js @@ -64,8 +64,9 @@ const Playground = () => { }, ]; + const defaultModel = 'gpt-4o-mini'; const [inputs, setInputs] = useState({ - model: 'gpt-4o-mini', + model: defaultModel, group: '', max_tokens: 0, temperature: 0, @@ -108,6 +109,11 @@ const Playground = () => { value: model, })); setModels(localModelOptions); + // if default model is not in the list, set the first one as default + const hasDefault = localModelOptions.some(option => option.value === defaultModel); + if (!hasDefault && localModelOptions.length > 0) { + setInputs((inputs) => ({ ...inputs, model: localModelOptions[0].value })); + } } else { showError(t(message)); } diff --git a/web/src/pages/Setting/Model/SettingGeminiModel.js b/web/src/pages/Setting/Model/SettingGeminiModel.js index 6f6da279..b802af1a 100644 --- a/web/src/pages/Setting/Model/SettingGeminiModel.js +++ b/web/src/pages/Setting/Model/SettingGeminiModel.js @@ -27,40 +27,48 @@ export default function SettingGeminiModel(props) { const [inputs, setInputs] = useState({ 'gemini.safety_settings': '', 'gemini.version_settings': '', - 'gemini.supported_imagine_models': [], + 'gemini.supported_imagine_models': '', 'gemini.thinking_adapter_enabled': false, 'gemini.thinking_adapter_budget_tokens_percentage': 0.6, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); - function onSubmit() { - const updateArray = compareObjects(inputs, inputsRow); - if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); - const requestQueue = updateArray.map((item) => { - let value = String(inputs[item.key]); - return API.put('/api/option/', { - key: item.key, - value, - }); - }); - setLoading(true); - Promise.all(requestQueue) - .then((res) => { - if (requestQueue.length === 1) { - if (res.includes(undefined)) return; - } else if (requestQueue.length > 1) { - if (res.includes(undefined)) - return showError(t('部分保存失败,请重试')); - } - showSuccess(t('保存成功')); - props.refresh(); + async function onSubmit() { + await refForm.current + .validate() + .then(() => { + const updateArray = compareObjects(inputs, inputsRow); + if (!updateArray.length) return showWarning(t('你似乎并没有修改什么')); + const requestQueue = updateArray.map((item) => { + let value = String(inputs[item.key]); + return API.put('/api/option/', { + key: item.key, + value, + }); + }); + setLoading(true); + Promise.all(requestQueue) + .then((res) => { + if (requestQueue.length === 1) { + if (res.includes(undefined)) return; + } else if (requestQueue.length > 1) { + if (res.includes(undefined)) + return showError(t('部分保存失败,请重试')); + } + showSuccess(t('保存成功')); + props.refresh(); + }) + .catch(() => { + showError(t('保存失败,请重试')); + }) + .finally(() => { + setLoading(false); + }); }) - .catch(() => { - showError(t('保存失败,请重试')); - }) - .finally(() => { - setLoading(false); + .catch((error) => { + console.error('Validation failed:', error); + showError(t('请检查输入')); }); } @@ -146,6 +154,14 @@ export default function SettingGeminiModel(props) { label={t('支持的图像模型')} placeholder={t('例如:') + '\n' + JSON.stringify(['gemini-2.0-flash-exp-image-generation'], null, 2)} onChange={(value) => setInputs({ ...inputs, 'gemini.supported_imagine_models': value })} + trigger='blur' + stopValidateWithError + rules={[ + { + validator: (rule, value) => verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} /> diff 
--git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..73626351 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -6,6 +6,7 @@ import { showError, showSuccess, showWarning, + verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -18,6 +19,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -46,6 +48,13 @@ export default function RequestRateLimit(props) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } + + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } + showSuccess(t('保存成功')); props.refresh(); }) @@ -147,6 +156,41 @@ export default function RequestRateLimit(props) { /> + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} + extraText={ +
+ {t('说明:')} + • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')} + • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')} + • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')} + • {t('分组速率配置优先级高于全局速率限制。')} + • {t('限制周期统一使用上方配置的“限制周期”值。')}
+ } + onChange={(value) => { + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); + }} + /> + +
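To tie the new `ModelRequestRateLimitGroup` option back to `setting/rate_limit.go` above, here is a hedged end-to-end sketch. It is standalone: `checkGroupLimits` is a hypothetical name that re-implements `CheckModelRequestRateLimitGroup` rather than importing `one-api/setting`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// checkGroupLimits mirrors CheckModelRequestRateLimitGroup from this PR:
// the option value has the shape {"group": [maxRequests, maxSuccessfulRequests]}.
func checkGroupLimits(jsonStr string) (map[string][2]int, error) {
	groups := make(map[string][2]int)
	if err := json.Unmarshal([]byte(jsonStr), &groups); err != nil {
		return nil, err
	}
	for group, limits := range groups {
		// maxRequests must be >= 0 and maxSuccessfulRequests must be >= 1,
		// matching the bullet points in the settings form above.
		if limits[0] < 0 || limits[1] < 1 {
			return nil, fmt.Errorf("group %s has invalid rate limit values: [%d, %d]", group, limits[0], limits[1])
		}
	}
	return groups, nil
}

func main() {
	// The same example the settings page suggests.
	cfg := `{"default": [200, 100], "vip": [0, 1000]}`
	groups, err := checkGroupLimits(cfg)
	if err != nil {
		fmt.Println("rejected:", err)
		return
	}
	total, success := groups["vip"][0], groups["vip"][1]
	fmt.Printf("vip: total=%d success=%d\n", total, success)
}
```

Note the asymmetry the `[0, 1000]` entry exercises: a total-request count of 0 passes validation (presumably meaning no cap on attempts), while the successful-request count must stay at least 1, which is why the check reads `limits[0] < 0 || limits[1] < 1` rather than a plain negativity test.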