- Introduced `isNoThinkingRequest` and `trimModelThinking` functions to manage model names and thinking configurations.
- Updated `GeminiHelper` to conditionally adjust the model name based on the thinking budget and request settings.
- Refactored `ThinkingAdaptor` to streamline the integration of thinking capabilities into Gemini requests.
- Cleaned up commented-out code in `FetchUpstreamModels` for clarity.

These changes improve the handling of model configurations and enhance the adaptability of the Gemini relay system.
697 lines
15 KiB
Go
697 lines
15 KiB
Go
package controller
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/model"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// OpenAIModel is a single model entry as it appears in an OpenAI-style
// model listing. Field names on the wire are given by the JSON tags.
type OpenAIModel struct {
	ID         string `json:"id"`
	Object     string `json:"object"`
	Created    int64  `json:"created"`
	OwnedBy    string `json:"owned_by"`
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
|
||
|
||
type OpenAIModelsResponse struct {
|
||
Data []OpenAIModel `json:"data"`
|
||
Success bool `json:"success"`
|
||
}
|
||
|
||
func GetAllChannels(c *gin.Context) {
|
||
p, _ := strconv.Atoi(c.Query("p"))
|
||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||
if p < 1 {
|
||
p = 1
|
||
}
|
||
if pageSize < 1 {
|
||
pageSize = common.ItemsPerPage
|
||
}
|
||
channelData := make([]*model.Channel, 0)
|
||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||
// type filter
|
||
typeStr := c.Query("type")
|
||
typeFilter := -1
|
||
if typeStr != "" {
|
||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||
typeFilter = t
|
||
}
|
||
}
|
||
|
||
var total int64
|
||
|
||
if enableTagMode {
|
||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||
return
|
||
}
|
||
for _, tag := range tags {
|
||
if tag != nil && *tag != "" {
|
||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||
if err == nil {
|
||
channelData = append(channelData, tagChannel...)
|
||
}
|
||
}
|
||
}
|
||
// 计算 tag 总数用于分页
|
||
total, _ = model.CountAllTags()
|
||
} else if typeFilter >= 0 {
|
||
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||
return
|
||
}
|
||
channelData = channels
|
||
total, _ = model.CountChannelsByType(typeFilter)
|
||
} else {
|
||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||
return
|
||
}
|
||
channelData = channels
|
||
total, _ = model.CountAllChannels()
|
||
}
|
||
|
||
// calculate type counts
|
||
typeCounts, _ := model.CountChannelsGroupByType()
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": gin.H{
|
||
"items": channelData,
|
||
"total": total,
|
||
"page": p,
|
||
"page_size": pageSize,
|
||
"type_counts": typeCounts,
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
func FetchUpstreamModels(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
channel, err := model.GetChannelById(id, true)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||
if channel.GetBaseURL() != "" {
|
||
baseURL = channel.GetBaseURL()
|
||
}
|
||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||
switch channel.Type {
|
||
case common.ChannelTypeGemini:
|
||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||
case common.ChannelTypeAli:
|
||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||
}
|
||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var result OpenAIModelsResponse
|
||
if err = json.Unmarshal(body, &result); err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||
})
|
||
return
|
||
}
|
||
|
||
var ids []string
|
||
for _, model := range result.Data {
|
||
id := model.ID
|
||
if channel.Type == common.ChannelTypeGemini {
|
||
id = strings.TrimPrefix(id, "models/")
|
||
}
|
||
ids = append(ids, id)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": ids,
|
||
})
|
||
}
|
||
|
||
func FixChannelsAbilities(c *gin.Context) {
|
||
count, err := model.FixAbility()
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": count,
|
||
})
|
||
}
|
||
|
||
func SearchChannels(c *gin.Context) {
|
||
keyword := c.Query("keyword")
|
||
group := c.Query("group")
|
||
modelKeyword := c.Query("model")
|
||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||
channelData := make([]*model.Channel, 0)
|
||
if enableTagMode {
|
||
tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
for _, tag := range tags {
|
||
if tag != nil && *tag != "" {
|
||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||
if err == nil {
|
||
channelData = append(channelData, tagChannel...)
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
channelData = channels
|
||
}
|
||
|
||
// calculate type counts for search results
|
||
typeCounts := make(map[int64]int64)
|
||
for _, channel := range channelData {
|
||
typeCounts[int64(channel.Type)]++
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": gin.H{
|
||
"items": channelData,
|
||
"type_counts": typeCounts,
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
func GetChannel(c *gin.Context) {
|
||
id, err := strconv.Atoi(c.Param("id"))
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
channel, err := model.GetChannelById(id, false)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": channel,
|
||
})
|
||
return
|
||
}
|
||
|
||
func AddChannel(c *gin.Context) {
|
||
channel := model.Channel{}
|
||
err := c.ShouldBindJSON(&channel)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
channel.CreatedTime = common.GetTimestamp()
|
||
keys := strings.Split(channel.Key, "\n")
|
||
if channel.Type == common.ChannelTypeVertexAi {
|
||
if channel.Other == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "部署地区不能为空",
|
||
})
|
||
return
|
||
} else {
|
||
if common.IsJsonStr(channel.Other) {
|
||
// must have default
|
||
regionMap := common.StrToMap(channel.Other)
|
||
if regionMap["default"] == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "部署地区必须包含default字段",
|
||
})
|
||
return
|
||
}
|
||
}
|
||
}
|
||
keys = []string{channel.Key}
|
||
}
|
||
channels := make([]model.Channel, 0, len(keys))
|
||
for _, key := range keys {
|
||
if key == "" {
|
||
continue
|
||
}
|
||
localChannel := channel
|
||
localChannel.Key = key
|
||
// Validate the length of the model name
|
||
models := strings.Split(localChannel.Models, ",")
|
||
for _, model := range models {
|
||
if len(model) > 255 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": fmt.Sprintf("模型名称过长: %s", model),
|
||
})
|
||
return
|
||
}
|
||
}
|
||
channels = append(channels, localChannel)
|
||
}
|
||
err = model.BatchInsertChannels(channels)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func DeleteChannel(c *gin.Context) {
|
||
id, _ := strconv.Atoi(c.Param("id"))
|
||
channel := model.Channel{Id: id}
|
||
err := channel.Delete()
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func DeleteDisabledChannel(c *gin.Context) {
|
||
rows, err := model.DeleteDisabledChannel()
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": rows,
|
||
})
|
||
return
|
||
}
|
||
|
||
// ChannelTag is the JSON request body of the tag-level channel handlers.
// Tag names the tag to operate on. The pointer fields stay nil when the
// corresponding key is absent from the request body.
type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}
|
||
|
||
func DisableTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil || channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.DisableChannelByTag(channelTag.Tag)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func EnableTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil || channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.EnableChannelByTag(channelTag.Tag)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
func EditTagChannels(c *gin.Context) {
|
||
channelTag := ChannelTag{}
|
||
err := c.ShouldBindJSON(&channelTag)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
if channelTag.Tag == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "tag不能为空",
|
||
})
|
||
return
|
||
}
|
||
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
})
|
||
return
|
||
}
|
||
|
||
// ChannelBatch is the JSON request body of the batch channel handlers: the
// channel ids to act on and an optional tag, which stays nil when the "tag"
// key is absent from the request body.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
|
||
|
||
func DeleteChannelBatch(c *gin.Context) {
|
||
channelBatch := ChannelBatch{}
|
||
err := c.ShouldBindJSON(&channelBatch)
|
||
if err != nil || len(channelBatch.Ids) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": len(channelBatch.Ids),
|
||
})
|
||
return
|
||
}
|
||
|
||
func UpdateChannel(c *gin.Context) {
|
||
channel := model.Channel{}
|
||
err := c.ShouldBindJSON(&channel)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
if channel.Type == common.ChannelTypeVertexAi {
|
||
if channel.Other == "" {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "部署地区不能为空",
|
||
})
|
||
return
|
||
} else {
|
||
if common.IsJsonStr(channel.Other) {
|
||
// must have default
|
||
regionMap := common.StrToMap(channel.Other)
|
||
if regionMap["default"] == nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "部署地区必须包含default字段",
|
||
})
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
err = channel.Update()
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": channel,
|
||
})
|
||
return
|
||
}
|
||
|
||
func FetchModels(c *gin.Context) {
|
||
var req struct {
|
||
BaseURL string `json:"base_url"`
|
||
Type int `json:"type"`
|
||
Key string `json:"key"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "Invalid request",
|
||
})
|
||
return
|
||
}
|
||
|
||
baseURL := req.BaseURL
|
||
if baseURL == "" {
|
||
baseURL = common.ChannelBaseURLs[req.Type]
|
||
}
|
||
|
||
client := &http.Client{}
|
||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||
|
||
request, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// remove line breaks and extra spaces.
|
||
key := strings.TrimSpace(req.Key)
|
||
// If the key contains a line break, only take the first part.
|
||
key = strings.Split(key, "\n")[0]
|
||
request.Header.Set("Authorization", "Bearer "+key)
|
||
|
||
response, err := client.Do(request)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
//check status code
|
||
if response.StatusCode != http.StatusOK {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": "Failed to fetch models",
|
||
})
|
||
return
|
||
}
|
||
defer response.Body.Close()
|
||
|
||
var result struct {
|
||
Data []struct {
|
||
ID string `json:"id"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var models []string
|
||
for _, model := range result.Data {
|
||
models = append(models, model.ID)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": models,
|
||
})
|
||
}
|
||
|
||
func BatchSetChannelTag(c *gin.Context) {
|
||
channelBatch := ChannelBatch{}
|
||
err := c.ShouldBindJSON(&channelBatch)
|
||
if err != nil || len(channelBatch.Ids) == 0 {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": "参数错误",
|
||
})
|
||
return
|
||
}
|
||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||
if err != nil {
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": len(channelBatch.Ids),
|
||
})
|
||
return
|
||
}
|
||
|
||
func GetTagModels(c *gin.Context) {
|
||
tag := c.Query("tag")
|
||
if tag == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"success": false,
|
||
"message": "tag不能为空",
|
||
})
|
||
return
|
||
}
|
||
|
||
channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"success": false,
|
||
"message": err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
var longestModels string
|
||
maxLength := 0
|
||
|
||
// Find the longest models string among all channels with the given tag
|
||
for _, channel := range channels {
|
||
if channel.Models != "" {
|
||
currentModels := strings.Split(channel.Models, ",")
|
||
if len(currentModels) > maxLength {
|
||
maxLength = len(currentModels)
|
||
longestModels = channel.Models
|
||
}
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "",
|
||
"data": longestModels,
|
||
})
|
||
return
|
||
}
|