Merge branch 'alpha' into refactor/model-pricing
This commit is contained in:
@@ -65,8 +65,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = constant.APITypeCoze
|
apiType = constant.APITypeCoze
|
||||||
case constant.ChannelTypeJimeng:
|
case constant.ChannelTypeJimeng:
|
||||||
apiType = constant.APITypeJimeng
|
apiType = constant.APITypeJimeng
|
||||||
case constant.ChannelTypeClaudeCode:
|
|
||||||
apiType = constant.APITypeClaudeCode
|
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return constant.APITypeOpenAI, false
|
return constant.APITypeOpenAI, false
|
||||||
|
|||||||
@@ -31,6 +31,5 @@ const (
|
|||||||
APITypeXai
|
APITypeXai
|
||||||
APITypeCoze
|
APITypeCoze
|
||||||
APITypeJimeng
|
APITypeJimeng
|
||||||
APITypeClaudeCode
|
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ const (
|
|||||||
ChannelTypeKling = 50
|
ChannelTypeKling = 50
|
||||||
ChannelTypeJimeng = 51
|
ChannelTypeJimeng = 51
|
||||||
ChannelTypeVidu = 52
|
ChannelTypeVidu = 52
|
||||||
ChannelTypeClaudeCode = 53
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -109,5 +108,4 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.klingai.com", //50
|
"https://api.klingai.com", //50
|
||||||
"https://visual.volcengineapi.com", //51
|
"https://visual.volcengineapi.com", //51
|
||||||
"https://api.vidu.cn", //52
|
"https://api.vidu.cn", //52
|
||||||
"https://api.anthropic.com", //53
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,11 +36,30 @@ type OpenAIModel struct {
|
|||||||
Parent string `json:"parent"`
|
Parent string `json:"parent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GoogleOpenAICompatibleModels []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
InputTokenLimit int `json:"inputTokenLimit"`
|
||||||
|
OutputTokenLimit int `json:"outputTokenLimit"`
|
||||||
|
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
TopK int `json:"topK,omitempty"`
|
||||||
|
MaxTemperature int `json:"maxTemperature,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type OpenAIModelsResponse struct {
|
type OpenAIModelsResponse struct {
|
||||||
Data []OpenAIModel `json:"data"`
|
Data []OpenAIModel `json:"data"`
|
||||||
Success bool `json:"success"`
|
Success bool `json:"success"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GoogleOpenAICompatibleResponse struct {
|
||||||
|
Models []GoogleOpenAICompatibleModels `json:"models"`
|
||||||
|
NextPageToken string `json:"nextPageToken"`
|
||||||
|
}
|
||||||
|
|
||||||
func parseStatusFilter(statusParam string) int {
|
func parseStatusFilter(statusParam string) int {
|
||||||
switch strings.ToLower(statusParam) {
|
switch strings.ToLower(statusParam) {
|
||||||
case "enabled", "1":
|
case "enabled", "1":
|
||||||
@@ -168,20 +187,52 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
|
||||||
|
var url string
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeGemini:
|
case constant.ChannelTypeGemini:
|
||||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||||
|
url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
|
||||||
case constant.ChannelTypeAli:
|
case constant.ChannelTypeAli:
|
||||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||||
|
default:
|
||||||
|
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||||
|
var body []byte
|
||||||
|
if channel.Type == constant.ChannelTypeGemini {
|
||||||
|
body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
|
||||||
|
} else {
|
||||||
|
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
}
|
}
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var result OpenAIModelsResponse
|
var result OpenAIModelsResponse
|
||||||
|
var parseSuccess bool
|
||||||
|
|
||||||
|
// 适配特殊格式
|
||||||
|
switch channel.Type {
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
var googleResult GoogleOpenAICompatibleResponse
|
||||||
|
if err = json.Unmarshal(body, &googleResult); err == nil {
|
||||||
|
// 转换Google格式到OpenAI格式
|
||||||
|
for _, model := range googleResult.Models {
|
||||||
|
for _, gModel := range model {
|
||||||
|
result.Data = append(result.Data, OpenAIModel{
|
||||||
|
ID: gModel.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parseSuccess = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果解析失败,尝试OpenAI格式
|
||||||
|
if !parseSuccess {
|
||||||
if err = json.Unmarshal(body, &result); err != nil {
|
if err = json.Unmarshal(body, &result); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -189,6 +240,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var ids []string
|
var ids []string
|
||||||
for _, model := range result.Data {
|
for _, model := range result.Data {
|
||||||
@@ -669,6 +721,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
|||||||
type PatchChannel struct {
|
type PatchChannel struct {
|
||||||
model.Channel
|
model.Channel
|
||||||
MultiKeyMode *string `json:"multi_key_mode"`
|
MultiKeyMode *string `json:"multi_key_mode"`
|
||||||
|
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
@@ -688,7 +741,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -704,6 +757,69 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||||
|
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||||
|
switch *channel.KeyMode {
|
||||||
|
case "append":
|
||||||
|
// 追加模式:将新密钥添加到现有密钥列表
|
||||||
|
if originChannel.Key != "" {
|
||||||
|
var newKeys []string
|
||||||
|
var existingKeys []string
|
||||||
|
|
||||||
|
// 解析现有密钥
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||||
|
// JSON数组格式
|
||||||
|
var arr []json.RawMessage
|
||||||
|
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||||
|
existingKeys = make([]string, len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
existingKeys[i] = string(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 换行分隔格式
|
||||||
|
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 Vertex AI 的特殊情况
|
||||||
|
if channel.Type == constant.ChannelTypeVertexAi {
|
||||||
|
// 尝试解析新密钥为JSON数组
|
||||||
|
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||||
|
array, err := getVertexArrayKeys(channel.Key)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "追加密钥解析失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newKeys = array
|
||||||
|
} else {
|
||||||
|
// 单个JSON密钥
|
||||||
|
newKeys = []string{channel.Key}
|
||||||
|
}
|
||||||
|
// 合并密钥
|
||||||
|
allKeys := append(existingKeys, newKeys...)
|
||||||
|
channel.Key = strings.Join(allKeys, "\n")
|
||||||
|
} else {
|
||||||
|
// 普通渠道的处理
|
||||||
|
inputKeys := strings.Split(channel.Key, "\n")
|
||||||
|
for _, key := range inputKeys {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key != "" {
|
||||||
|
newKeys = append(newKeys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 合并密钥
|
||||||
|
allKeys := append(existingKeys, newKeys...)
|
||||||
|
channel.Key = strings.Join(allKeys, "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "replace":
|
||||||
|
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||||
|
}
|
||||||
|
}
|
||||||
err = channel.Update()
|
err = channel.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ExchangeCodeRequest 授权码交换请求
|
|
||||||
type ExchangeCodeRequest struct {
|
|
||||||
AuthorizationCode string `json:"authorization_code" binding:"required"`
|
|
||||||
CodeVerifier string `json:"code_verifier" binding:"required"`
|
|
||||||
State string `json:"state" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateClaudeOAuthURL 生成Claude OAuth授权URL
|
|
||||||
func GenerateClaudeOAuthURL(c *gin.Context) {
|
|
||||||
params, err := service.GenerateOAuthParams()
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "生成OAuth授权URL失败: " + err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "生成OAuth授权URL成功",
|
|
||||||
"data": params,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExchangeClaudeOAuthCode 交换Claude OAuth授权码
|
|
||||||
func ExchangeClaudeOAuthCode(c *gin.Context) {
|
|
||||||
var req ExchangeCodeRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "请求参数错误: " + err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析授权码
|
|
||||||
cleanedCode, err := service.ParseAuthorizationCode(req.AuthorizationCode)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 交换token
|
|
||||||
tokenResult, err := service.ExchangeCode(cleanedCode, req.CodeVerifier, req.State, nil)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Claude OAuth token exchange failed: " + err.Error())
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "授权码交换失败: " + err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "授权码交换成功",
|
|
||||||
"data": tokenResult,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -47,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
|||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if constant2.ErrorLogEnabled && err != nil {
|
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
|
||||||
// 保存错误日志到mysql中
|
// 保存错误日志到mysql中
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
@@ -62,6 +62,14 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
|||||||
other["channel_id"] = channelId
|
other["channel_id"] = channelId
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||||
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||||
|
if isMultiKey {
|
||||||
|
adminInfo["is_multi_key"] = true
|
||||||
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||||
|
}
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package gemini
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -56,7 +56,7 @@ type FunctionCall struct {
|
|||||||
Arguments any `json:"args"`
|
Arguments any `json:"args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionResponse struct {
|
type GeminiFunctionResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Response map[string]interface{} `json:"response"`
|
Response map[string]interface{} `json:"response"`
|
||||||
}
|
}
|
||||||
@@ -81,7 +81,7 @@ type GeminiPart struct {
|
|||||||
Thought bool `json:"thought,omitempty"`
|
Thought bool `json:"thought,omitempty"`
|
||||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||||
1
go.mod
1
go.mod
@@ -87,7 +87,6 @@ require (
|
|||||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||||
golang.org/x/arch v0.12.0 // indirect
|
golang.org/x/arch v0.12.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||||
golang.org/x/oauth2 v0.30.0 // indirect
|
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.30.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.2 // indirect
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -231,8 +231,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
|||||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
|
||||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
|
|||||||
3
main.go
3
main.go
@@ -86,9 +86,6 @@ func main() {
|
|||||||
// 数据看板
|
// 数据看板
|
||||||
go model.UpdateQuotaData()
|
go model.UpdateQuotaData()
|
||||||
|
|
||||||
// Start Claude Code token refresh scheduler
|
|
||||||
service.StartClaudeTokenRefreshScheduler()
|
|
||||||
|
|
||||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -269,6 +269,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel.ChannelInfo.IsMultiKey {
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
|
||||||
|
} else {
|
||||||
|
// 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
|
||||||
}
|
}
|
||||||
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||||
|
|||||||
@@ -284,6 +284,21 @@ func FixAbility() (int, int, error) {
|
|||||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||||
}
|
}
|
||||||
defer fixLock.Unlock()
|
defer fixLock.Unlock()
|
||||||
|
|
||||||
|
// truncate abilities table
|
||||||
|
if common.UsingSQLite {
|
||||||
|
err := DB.Exec("DELETE FROM abilities").Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := DB.Exec("TRUNCATE TABLE abilities").Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
// Find all channels
|
// Find all channels
|
||||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ type Channel struct {
|
|||||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||||
// add after v0.8.5
|
// add after v0.8.5
|
||||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||||
|
|
||||||
|
// cache info
|
||||||
|
Keys []string `json:"-" gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChannelInfo struct {
|
type ChannelInfo struct {
|
||||||
@@ -71,6 +74,9 @@ func (channel *Channel) getKeys() []string {
|
|||||||
if channel.Key == "" {
|
if channel.Key == "" {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
if len(channel.Keys) > 0 {
|
||||||
|
return channel.Keys
|
||||||
|
}
|
||||||
trimmed := strings.TrimSpace(channel.Key)
|
trimmed := strings.TrimSpace(channel.Key)
|
||||||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||||||
if strings.HasPrefix(trimmed, "[") {
|
if strings.HasPrefix(trimmed, "[") {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -66,6 +67,20 @@ func InitChannelCache() {
|
|||||||
|
|
||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
|
//channelsIDM = newChannelId2channel
|
||||||
|
for i, channel := range newChannelId2channel {
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
channel.Keys = channel.getKeys()
|
||||||
|
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||||
|
if oldChannel, ok := channelsIDM[i]; ok {
|
||||||
|
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||||||
|
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||||||
|
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
channelsIDM = newChannelId2channel
|
channelsIDM = newChannelId2channel
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
common.SysLog("channels synced from database")
|
common.SysLog("channels synced from database")
|
||||||
@@ -203,9 +218,6 @@ func CacheGetChannel(id int) (*Channel, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
}
|
}
|
||||||
if c.Status != common.ChannelStatusEnabled {
|
|
||||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
|
||||||
}
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,9 +236,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||||||
}
|
}
|
||||||
if c.Status != common.ChannelStatusEnabled {
|
|
||||||
return nil, fmt.Errorf("渠道# %d,已被禁用", id)
|
|
||||||
}
|
|
||||||
return &c.ChannelInfo, nil
|
return &c.ChannelInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type Adaptor interface {
|
|||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||||
|
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskAdaptor interface {
|
type TaskAdaptor interface {
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
|||||||
RequestMode int
|
RequestMode int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
c.Set("request_model", request.Model)
|
c.Set("request_model", request.Model)
|
||||||
c.Set("converted_request", request)
|
c.Set("converted_request", request)
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
|||||||
RequestMode int
|
RequestMode int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
package claude_code
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/relay/channel"
|
|
||||||
"one-api/relay/channel/claude"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/types"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
RequestModeCompletion = 1
|
|
||||||
RequestModeMessage = 2
|
|
||||||
DefaultSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
|
||||||
)
|
|
||||||
|
|
||||||
type Adaptor struct {
|
|
||||||
RequestMode int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
|
||||||
// Use configured system prompt if available, otherwise use default
|
|
||||||
if info.ChannelSetting.SystemPrompt != "" {
|
|
||||||
request.System = info.ChannelSetting.SystemPrompt
|
|
||||||
} else {
|
|
||||||
request.System = DefaultSystemPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
|
|
||||||
a.RequestMode = RequestModeCompletion
|
|
||||||
} else {
|
|
||||||
a.RequestMode = RequestModeMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
||||||
if a.RequestMode == RequestModeMessage {
|
|
||||||
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
|
||||||
|
|
||||||
// Parse accesstoken|refreshtoken format and use only the access token
|
|
||||||
accessToken := info.ApiKey
|
|
||||||
if strings.Contains(info.ApiKey, "|") {
|
|
||||||
parts := strings.Split(info.ApiKey, "|")
|
|
||||||
if len(parts) >= 1 {
|
|
||||||
accessToken = parts[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Claude Code specific headers - force override
|
|
||||||
req.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
// 只有在没有设置的情况下才设置 anthropic-version
|
|
||||||
if req.Get("anthropic-version") == "" {
|
|
||||||
req.Set("anthropic-version", "2023-06-01")
|
|
||||||
}
|
|
||||||
req.Set("content-type", "application/json")
|
|
||||||
|
|
||||||
// 只有在 user-agent 不包含 claude-cli 时才设置
|
|
||||||
userAgent := req.Get("user-agent")
|
|
||||||
if userAgent == "" || !strings.Contains(strings.ToLower(userAgent), "claude-cli") {
|
|
||||||
req.Set("user-agent", "claude-cli/1.0.61 (external, cli)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只有在 anthropic-beta 不包含 claude-code 时才设置
|
|
||||||
anthropicBeta := req.Get("anthropic-beta")
|
|
||||||
if anthropicBeta == "" || !strings.Contains(strings.ToLower(anthropicBeta), "claude-code") {
|
|
||||||
req.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
|
|
||||||
}
|
|
||||||
// if Anthropic-Dangerous-Direct-Browser-Access
|
|
||||||
anthropicDangerousDirectBrowserAccess := req.Get("anthropic-dangerous-direct-browser-access")
|
|
||||||
if anthropicDangerousDirectBrowserAccess == "" {
|
|
||||||
req.Set("anthropic-dangerous-direct-browser-access", "true")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
||||||
if request == nil {
|
|
||||||
return nil, errors.New("request is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.RequestMode == RequestModeCompletion {
|
|
||||||
return claude.RequestOpenAI2ClaudeComplete(*request), nil
|
|
||||||
} else {
|
|
||||||
claudeRequest, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use configured system prompt if available, otherwise use default
|
|
||||||
if info.ChannelSetting.SystemPrompt != "" {
|
|
||||||
claudeRequest.System = info.ChannelSetting.SystemPrompt
|
|
||||||
} else {
|
|
||||||
claudeRequest.System = DefaultSystemPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
return claudeRequest, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
|
||||||
return nil, errors.New("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
||||||
if info.IsStream {
|
|
||||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
|
||||||
} else {
|
|
||||||
err, usage = claude.ClaudeHandler(c, resp, a.RequestMode, info)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
|
||||||
return ModelList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
|
||||||
return ChannelName
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
package claude_code
|
|
||||||
|
|
||||||
var ModelList = []string{
|
|
||||||
"claude-3-5-haiku-20241022",
|
|
||||||
"claude-3-5-sonnet-20241022",
|
|
||||||
"claude-3-7-sonnet-20250219",
|
|
||||||
"claude-3-7-sonnet-20250219-thinking",
|
|
||||||
"claude-sonnet-4-20250514",
|
|
||||||
"claude-sonnet-4-20250514-thinking",
|
|
||||||
"claude-opus-4-20250514",
|
|
||||||
"claude-opus-4-20250514-thinking",
|
|
||||||
}
|
|
||||||
|
|
||||||
var ChannelName = "claude_code"
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
package claude_code
|
|
||||||
|
|
||||||
// Claude Code uses the same DTO structures as Claude since it's based on the same API
|
|
||||||
// This file is kept for consistency with the channel structure pattern
|
|
||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// ConvertAudioRequest implements channel.Adaptor.
|
// ConvertAudioRequest implements channel.Adaptor.
|
||||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
|||||||
BotType int
|
BotType int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -20,6 +20,26 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||||
|
if len(request.Contents) > 0 {
|
||||||
|
for i, content := range request.Contents {
|
||||||
|
if i == 0 {
|
||||||
|
if request.Contents[0].Role == "" {
|
||||||
|
request.Contents[0].Role = "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.FileData != nil {
|
||||||
|
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
|
||||||
|
part.FileData.MimeType = "video/webm"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||||
adaptor := openai.Adaptor{}
|
adaptor := openai.Adaptor{}
|
||||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||||
@@ -51,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// build gemini imagen request
|
// build gemini imagen request
|
||||||
geminiRequest := GeminiImageRequest{
|
geminiRequest := dto.GeminiImageRequest{
|
||||||
Instances: []GeminiImageInstance{
|
Instances: []dto.GeminiImageInstance{
|
||||||
{
|
{
|
||||||
Prompt: request.Prompt,
|
Prompt: request.Prompt,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Parameters: GeminiImageParameters{
|
Parameters: dto.GeminiImageParameters{
|
||||||
SampleCount: request.N,
|
SampleCount: request.N,
|
||||||
AspectRatio: aspectRatio,
|
AspectRatio: aspectRatio,
|
||||||
PersonGeneration: "allow_adult", // default allow adult
|
PersonGeneration: "allow_adult", // default allow adult
|
||||||
@@ -138,9 +158,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only process the first input
|
// only process the first input
|
||||||
geminiRequest := GeminiEmbeddingRequest{
|
geminiRequest := dto.GeminiEmbeddingRequest{
|
||||||
Content: GeminiChatContent{
|
Content: dto.GeminiChatContent{
|
||||||
Parts: []GeminiPart{
|
Parts: []dto.GeminiPart{
|
||||||
{
|
{
|
||||||
Text: inputs[0],
|
Text: inputs[0],
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -28,7 +29,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析为 Gemini 原生响应格式
|
// 解析为 Gemini 原生响应格式
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse dto.GeminiChatResponse
|
||||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
@@ -71,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
responseText := strings.Builder{}
|
responseText := strings.Builder{}
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse dto.GeminiChatResponse
|
||||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
@@ -110,10 +111,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
}
|
}
|
||||||
|
info.SendResponseCount++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if info.SendResponseCount == 0 {
|
||||||
|
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
if imageCount != 0 {
|
if imageCount != 0 {
|
||||||
if usage.CompletionTokens == 0 {
|
if usage.CompletionTokens == 0 {
|
||||||
usage.CompletionTokens = imageCount * 258
|
usage.CompletionTokens = imageCount * 258
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
|||||||
return budget
|
return budget
|
||||||
}
|
}
|
||||||
|
|
||||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
modelName := info.UpstreamModelName
|
modelName := info.UpstreamModelName
|
||||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||||
@@ -93,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
|||||||
if len(parts) == 2 && parts[1] != "" {
|
if len(parts) == 2 && parts[1] != "" {
|
||||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||||
IncludeThoughts: true,
|
IncludeThoughts: true,
|
||||||
}
|
}
|
||||||
@@ -113,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isUnsupported {
|
if isUnsupported {
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||||
IncludeThoughts: true,
|
IncludeThoughts: true,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||||
IncludeThoughts: true,
|
IncludeThoughts: true,
|
||||||
}
|
}
|
||||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
@@ -128,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||||
if !isNew25Pro {
|
if !isNew25Pro {
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||||
ThinkingBudget: common.GetPointer(0),
|
ThinkingBudget: common.GetPointer(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||||
|
|
||||||
geminiRequest := GeminiChatRequest{
|
geminiRequest := dto.GeminiChatRequest{
|
||||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||||
GenerationConfig: GeminiChatGenerationConfig{
|
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||||
Temperature: textRequest.Temperature,
|
Temperature: textRequest.Temperature,
|
||||||
TopP: textRequest.TopP,
|
TopP: textRequest.TopP,
|
||||||
MaxOutputTokens: textRequest.MaxTokens,
|
MaxOutputTokens: textRequest.MaxTokens,
|
||||||
@@ -158,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
|
|
||||||
ThinkingAdaptor(&geminiRequest, info)
|
ThinkingAdaptor(&geminiRequest, info)
|
||||||
|
|
||||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||||
for _, category := range SafetySettingList {
|
for _, category := range SafetySettingList {
|
||||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||||
Category: category,
|
Category: category,
|
||||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||||
})
|
})
|
||||||
@@ -198,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
functions = append(functions, tool.Function)
|
functions = append(functions, tool.Function)
|
||||||
}
|
}
|
||||||
if codeExecution {
|
if codeExecution {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||||
CodeExecution: make(map[string]string),
|
CodeExecution: make(map[string]string),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if googleSearch {
|
if googleSearch {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||||
GoogleSearch: make(map[string]string),
|
GoogleSearch: make(map[string]string),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(functions) > 0 {
|
if len(functions) > 0 {
|
||||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
|
||||||
FunctionDeclarations: functions,
|
FunctionDeclarations: functions,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -238,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
continue
|
continue
|
||||||
} else if message.Role == "tool" || message.Role == "function" {
|
} else if message.Role == "tool" || message.Role == "function" {
|
||||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -265,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
functionResp := &FunctionResponse{
|
functionResp := &dto.GeminiFunctionResponse{
|
||||||
Name: name,
|
Name: name,
|
||||||
Response: contentMap,
|
Response: contentMap,
|
||||||
}
|
}
|
||||||
|
|
||||||
*parts = append(*parts, GeminiPart{
|
*parts = append(*parts, dto.GeminiPart{
|
||||||
FunctionResponse: functionResp,
|
FunctionResponse: functionResp,
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var parts []GeminiPart
|
var parts []dto.GeminiPart
|
||||||
content := GeminiChatContent{
|
content := dto.GeminiChatContent{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
}
|
}
|
||||||
// isToolCall := false
|
// isToolCall := false
|
||||||
@@ -290,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
toolCall := GeminiPart{
|
toolCall := dto.GeminiPart{
|
||||||
FunctionCall: &FunctionCall{
|
FunctionCall: &dto.FunctionCall{
|
||||||
FunctionName: call.Function.Name,
|
FunctionName: call.Function.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
},
|
},
|
||||||
@@ -308,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
if part.Text == "" {
|
if part.Text == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
Text: part.Text,
|
Text: part.Text,
|
||||||
})
|
})
|
||||||
} else if part.Type == dto.ContentTypeImageURL {
|
} else if part.Type == dto.ContentTypeImageURL {
|
||||||
@@ -331,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||||
}
|
}
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||||
Data: fileData.Base64Data,
|
Data: fileData.Base64Data,
|
||||||
},
|
},
|
||||||
@@ -342,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: format,
|
MimeType: format,
|
||||||
Data: base64String,
|
Data: base64String,
|
||||||
},
|
},
|
||||||
@@ -357,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: format,
|
MimeType: format,
|
||||||
Data: base64String,
|
Data: base64String,
|
||||||
},
|
},
|
||||||
@@ -371,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, dto.GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &dto.GeminiInlineData{
|
||||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||||
Data: base64String,
|
Data: base64String,
|
||||||
},
|
},
|
||||||
@@ -392,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(system_content) > 0 {
|
if len(system_content) > 0 {
|
||||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||||
Parts: []GeminiPart{
|
Parts: []dto.GeminiPart{
|
||||||
{
|
{
|
||||||
Text: strings.Join(system_content, "\n"),
|
Text: strings.Join(system_content, "\n"),
|
||||||
},
|
},
|
||||||
@@ -636,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||||
var argsBytes []byte
|
var argsBytes []byte
|
||||||
var err error
|
var err error
|
||||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||||
@@ -658,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: helper.GetResponseID(c),
|
Id: helper.GetResponseID(c),
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@@ -725,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
|||||||
return &fullTextResponse
|
return &fullTextResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||||
isStop := false
|
isStop := false
|
||||||
hasImage := false
|
|
||||||
for _, candidate := range geminiResponse.Candidates {
|
for _, candidate := range geminiResponse.Candidates {
|
||||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||||
isStop = true
|
isStop = true
|
||||||
@@ -759,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|||||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||||
imgText := ""
|
imgText := ""
|
||||||
texts = append(texts, imgText)
|
texts = append(texts, imgText)
|
||||||
hasImage = true
|
|
||||||
}
|
}
|
||||||
} else if part.FunctionCall != nil {
|
} else if part.FunctionCall != nil {
|
||||||
isTools = true
|
isTools = true
|
||||||
@@ -796,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|||||||
var response dto.ChatCompletionsStreamResponse
|
var response dto.ChatCompletionsStreamResponse
|
||||||
response.Object = "chat.completion.chunk"
|
response.Object = "chat.completion.chunk"
|
||||||
response.Choices = choices
|
response.Choices = choices
|
||||||
return &response, isStop, hasImage
|
return &response, isStop
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||||
@@ -824,23 +822,31 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|||||||
// responseText := ""
|
// responseText := ""
|
||||||
id := helper.GetResponseID(c)
|
id := helper.GetResponseID(c)
|
||||||
createAt := common.GetTimestamp()
|
createAt := common.GetTimestamp()
|
||||||
|
responseText := strings.Builder{}
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
var imageCount int
|
var imageCount int
|
||||||
|
|
||||||
respCount := 0
|
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse dto.GeminiChatResponse
|
||||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
for _, candidate := range geminiResponse.Candidates {
|
||||||
if hasImage {
|
for _, part := range candidate.Content.Parts {
|
||||||
|
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||||
imageCount++
|
imageCount++
|
||||||
}
|
}
|
||||||
|
if part.Text != "" {
|
||||||
|
responseText.WriteString(part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||||
|
|
||||||
response.Id = id
|
response.Id = id
|
||||||
response.Created = createAt
|
response.Created = createAt
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
@@ -858,7 +864,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if respCount == 0 {
|
if info.SendResponseCount == 0 {
|
||||||
// send first response
|
// send first response
|
||||||
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
|
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -873,11 +879,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|||||||
if isStop {
|
if isStop {
|
||||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
|
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
|
||||||
}
|
}
|
||||||
respCount++
|
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
if respCount == 0 {
|
if info.SendResponseCount == 0 {
|
||||||
// 空补全,报错不计费
|
// 空补全,报错不计费
|
||||||
// empty response, throw an error
|
// empty response, throw an error
|
||||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||||
@@ -892,6 +897,16 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|||||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||||
|
|
||||||
|
if usage.CompletionTokens == 0 {
|
||||||
|
str := responseText.String()
|
||||||
|
if len(str) > 0 {
|
||||||
|
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
// 空补全,不需要使用量
|
||||||
|
usage = &dto.Usage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||||
err := handleFinalStream(c, info, response)
|
err := handleFinalStream(c, info, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -913,7 +928,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(string(responseBody))
|
println(string(responseBody))
|
||||||
}
|
}
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse dto.GeminiChatResponse
|
||||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
@@ -959,7 +974,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
|||||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
var geminiResponse GeminiEmbeddingResponse
|
var geminiResponse dto.GeminiEmbeddingResponse
|
||||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -1005,7 +1020,7 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
var geminiResponse GeminiImageResponse
|
var geminiResponse dto.GeminiImageResponse
|
||||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -13,11 +12,18 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -16,6 +16,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
openaiAdaptor := openai.Adaptor{}
|
openaiAdaptor := openai.Adaptor{}
|
||||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||||
|
|||||||
@@ -34,6 +34,15 @@ type Adaptor struct {
|
|||||||
ResponseFormat string
|
ResponseFormat string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||||
|
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||||
|
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
//if !strings.Contains(request.Model, "claude") {
|
//if !strings.Contains(request.Model, "claude") {
|
||||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||||
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
|
||||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||||
}
|
}
|
||||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -16,11 +18,14 @@ import (
|
|||||||
// 辅助函数
|
// 辅助函数
|
||||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||||
info.SendResponseCount++
|
info.SendResponseCount++
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case relaycommon.RelayFormatOpenAI:
|
||||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||||
case relaycommon.RelayFormatClaude:
|
case relaycommon.RelayFormatClaude:
|
||||||
return handleClaudeFormat(c, data, info)
|
return handleClaudeFormat(c, data, info)
|
||||||
|
case relaycommon.RelayFormatGemini:
|
||||||
|
return handleGeminiFormat(c, data, info)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||||
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||||
|
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||||
|
|
||||||
|
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||||
|
if geminiResponse == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// send gemini format response
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||||
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
return errors.New("streaming error: flusher not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
@@ -185,6 +220,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
|||||||
for _, resp := range claudeResponses {
|
for _, resp := range claudeResponses {
|
||||||
_ = helper.ClaudeData(c, *resp)
|
_ = helper.ClaudeData(c, *resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case relaycommon.RelayFormatGemini:
|
||||||
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
|
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||||
|
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||||
|
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||||
|
// 暂不知是否有程序会不兼容。
|
||||||
|
|
||||||
|
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||||
|
|
||||||
|
// openai 流响应开头的空数据
|
||||||
|
if geminiResponse == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling gemini response: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送最终的 Gemini 响应
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||||
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -223,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
responseBody = claudeRespStr
|
responseBody = claudeRespStr
|
||||||
|
case relaycommon.RelayFormatGemini:
|
||||||
|
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||||
|
geminiRespStr, err := common.Marshal(geminiResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
|
}
|
||||||
|
responseBody = geminiRespStr
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||||
adaptor := openai.Adaptor{}
|
adaptor := openai.Adaptor{}
|
||||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
|||||||
Timestamp int64
|
Timestamp int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -44,6 +44,11 @@ type Adaptor struct {
|
|||||||
AccountCredentials Credentials
|
AccountCredentials Credentials
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||||
|
geminiAdaptor := gemini.Adaptor{}
|
||||||
|
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||||
c.Set("request_model", v)
|
c.Set("request_model", v)
|
||||||
|
|||||||
@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
|||||||
})
|
})
|
||||||
|
|
||||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
var cacheKey string
|
||||||
|
if info.ChannelIsMultiKey {
|
||||||
|
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
|
||||||
|
} else {
|
||||||
|
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||||
|
}
|
||||||
val, err := Cache.Get(cacheKey)
|
val, err := Cache.Get(cacheKey)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return val.(string), nil
|
return val.(string), nil
|
||||||
|
|||||||
@@ -23,6 +23,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
//panic("implement me")
|
//panic("implement me")
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
|||||||
request *dto.GeneralOpenAIRequest
|
request *dto.GeneralOpenAIRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -16,6 +16,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -18,6 +18,11 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
//TODO implement me
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
//TODO implement me
|
//TODO implement me
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
|
|||||||
@@ -62,6 +62,8 @@ type ResponsesUsageInfo struct {
|
|||||||
type RelayInfo struct {
|
type RelayInfo struct {
|
||||||
ChannelType int
|
ChannelType int
|
||||||
ChannelId int
|
ChannelId int
|
||||||
|
ChannelIsMultiKey bool // 是否多密钥
|
||||||
|
ChannelMultiKeyIndex int // 多密钥索引
|
||||||
TokenId int
|
TokenId int
|
||||||
TokenKey string
|
TokenKey string
|
||||||
UserId int
|
UserId int
|
||||||
@@ -260,6 +262,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
IsFirstThinkingContent: true,
|
IsFirstThinkingContent: true,
|
||||||
SendLastThinkingContent: false,
|
SendLastThinkingContent: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||||
|
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||||
info.IsPlayground = true
|
info.IsPlayground = true
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||||
request := &gemini.GeminiChatRequest{}
|
request := &dto.GeminiChatRequest{}
|
||||||
err := common.UnmarshalBodyReusable(c, request)
|
err := common.UnmarshalBodyReusable(c, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
|
||||||
var inputTexts []string
|
var inputTexts []string
|
||||||
for _, content := range textRequest.Contents {
|
for _, content := range textRequest.Contents {
|
||||||
for _, part := range content.Parts {
|
for _, part := range content.Parts {
|
||||||
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
|||||||
return sensitiveWords, err
|
return sensitiveWords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||||
// 计算输入 token 数量
|
// 计算输入 token 数量
|
||||||
var inputTexts []string
|
var inputTexts []string
|
||||||
for _, content := range req.Contents {
|
for _, content := range req.Contents {
|
||||||
@@ -78,9 +78,13 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
|||||||
return inputTokens
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0
|
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
|
||||||
|
if configBudget != nil && *configBudget == 0 {
|
||||||
|
// 如果思考预算为 0,则认为是非思考请求
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -202,7 +206,12 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
|||||||
}
|
}
|
||||||
requestBody = bytes.NewReader(body)
|
requestBody = bytes.NewReader(body)
|
||||||
} else {
|
} else {
|
||||||
jsonData, err := common.Marshal(req)
|
// 使用 ConvertGeminiRequest 转换请求格式
|
||||||
|
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,10 +305,10 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
if userQuota <= 0 {
|
if userQuota <= 0 {
|
||||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
|
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
|
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
relayInfo.UserQuota = userQuota
|
relayInfo.UserQuota = userQuota
|
||||||
if userQuota > 100*preConsumedQuota {
|
if userQuota > 100*preConsumedQuota {
|
||||||
@@ -332,7 +332,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry())
|
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/relay/channel/baidu"
|
"one-api/relay/channel/baidu"
|
||||||
"one-api/relay/channel/baidu_v2"
|
"one-api/relay/channel/baidu_v2"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
"one-api/relay/channel/claude_code"
|
|
||||||
"one-api/relay/channel/cloudflare"
|
"one-api/relay/channel/cloudflare"
|
||||||
"one-api/relay/channel/cohere"
|
"one-api/relay/channel/cohere"
|
||||||
"one-api/relay/channel/coze"
|
"one-api/relay/channel/coze"
|
||||||
@@ -99,8 +98,6 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &coze.Adaptor{}
|
return &coze.Adaptor{}
|
||||||
case constant.APITypeJimeng:
|
case constant.APITypeJimeng:
|
||||||
return &jimeng.Adaptor{}
|
return &jimeng.Adaptor{}
|
||||||
case constant.APITypeClaudeCode:
|
|
||||||
return &claude_code.Adaptor{}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,9 +120,6 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
|
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
|
||||||
channelRoute.GET("/tag/models", controller.GetTagModels)
|
channelRoute.GET("/tag/models", controller.GetTagModels)
|
||||||
channelRoute.POST("/copy/:id", controller.CopyChannel)
|
channelRoute.POST("/copy/:id", controller.CopyChannel)
|
||||||
// Claude OAuth路由
|
|
||||||
channelRoute.GET("/claude/oauth/url", controller.GenerateClaudeOAuthURL)
|
|
||||||
channelRoute.POST("/claude/oauth/exchange", controller.ExchangeClaudeOAuthCode)
|
|
||||||
}
|
}
|
||||||
tokenRoute := apiRouter.Group("/token")
|
tokenRoute := apiRouter.Group("/token")
|
||||||
tokenRoute.Use(middleware.UserAuth())
|
tokenRoute.Use(middleware.UserAuth())
|
||||||
|
|||||||
@@ -1,171 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Default OAuth configuration values
|
|
||||||
DefaultAuthorizeURL = "https://claude.ai/oauth/authorize"
|
|
||||||
DefaultTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
|
||||||
DefaultClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
|
||||||
DefaultRedirectURI = "https://console.anthropic.com/oauth/code/callback"
|
|
||||||
DefaultScopes = "user:inference"
|
|
||||||
)
|
|
||||||
|
|
||||||
// getOAuthValues returns OAuth configuration values from environment variables or defaults
|
|
||||||
func getOAuthValues() (authorizeURL, tokenURL, clientID, redirectURI, scopes string) {
|
|
||||||
authorizeURL = os.Getenv("CLAUDE_AUTHORIZE_URL")
|
|
||||||
if authorizeURL == "" {
|
|
||||||
authorizeURL = DefaultAuthorizeURL
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenURL = os.Getenv("CLAUDE_TOKEN_URL")
|
|
||||||
if tokenURL == "" {
|
|
||||||
tokenURL = DefaultTokenURL
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID = os.Getenv("CLAUDE_CLIENT_ID")
|
|
||||||
if clientID == "" {
|
|
||||||
clientID = DefaultClientID
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectURI = os.Getenv("CLAUDE_REDIRECT_URI")
|
|
||||||
if redirectURI == "" {
|
|
||||||
redirectURI = DefaultRedirectURI
|
|
||||||
}
|
|
||||||
|
|
||||||
scopes = os.Getenv("CLAUDE_SCOPES")
|
|
||||||
if scopes == "" {
|
|
||||||
scopes = DefaultScopes
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type OAuth2Credentials struct {
|
|
||||||
AuthURL string `json:"auth_url"`
|
|
||||||
CodeVerifier string `json:"code_verifier"`
|
|
||||||
State string `json:"state"`
|
|
||||||
CodeChallenge string `json:"code_challenge"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClaudeOAuthConfig returns the Claude OAuth2 configuration
|
|
||||||
func GetClaudeOAuthConfig() *oauth2.Config {
|
|
||||||
authorizeURL, tokenURL, clientID, redirectURI, scopes := getOAuthValues()
|
|
||||||
|
|
||||||
return &oauth2.Config{
|
|
||||||
ClientID: clientID,
|
|
||||||
RedirectURL: redirectURI,
|
|
||||||
Scopes: strings.Split(scopes, " "),
|
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
AuthURL: authorizeURL,
|
|
||||||
TokenURL: tokenURL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getOAuthConfig is kept for backward compatibility
|
|
||||||
func getOAuthConfig() *oauth2.Config {
|
|
||||||
return GetClaudeOAuthConfig()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateOAuthParams generates OAuth authorization URL and related parameters
|
|
||||||
func GenerateOAuthParams() (*OAuth2Credentials, error) {
|
|
||||||
config := getOAuthConfig()
|
|
||||||
|
|
||||||
// Generate PKCE parameters
|
|
||||||
codeVerifier := oauth2.GenerateVerifier()
|
|
||||||
state := oauth2.GenerateVerifier() // Reuse generator as state
|
|
||||||
|
|
||||||
// Generate authorization URL
|
|
||||||
authURL := config.AuthCodeURL(state,
|
|
||||||
oauth2.S256ChallengeOption(codeVerifier),
|
|
||||||
oauth2.SetAuthURLParam("code", "true"), // Claude-specific parameter
|
|
||||||
)
|
|
||||||
|
|
||||||
return &OAuth2Credentials{
|
|
||||||
AuthURL: authURL,
|
|
||||||
CodeVerifier: codeVerifier,
|
|
||||||
State: state,
|
|
||||||
CodeChallenge: oauth2.S256ChallengeFromVerifier(codeVerifier),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExchangeCode
|
|
||||||
func ExchangeCode(authorizationCode, codeVerifier, state string, client *http.Client) (*oauth2.Token, error) {
|
|
||||||
config := getOAuthConfig()
|
|
||||||
|
|
||||||
if strings.Contains(authorizationCode, "#") {
|
|
||||||
parts := strings.Split(authorizationCode, "#")
|
|
||||||
if len(parts) > 0 {
|
|
||||||
authorizationCode = parts[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
if client != nil {
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := config.Exchange(ctx, authorizationCode,
|
|
||||||
oauth2.VerifierOption(codeVerifier),
|
|
||||||
oauth2.SetAuthURLParam("state", state),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("token exchange failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseAuthorizationCode(input string) (string, error) {
|
|
||||||
if input == "" {
|
|
||||||
return "", fmt.Errorf("please provide a valid authorization code")
|
|
||||||
}
|
|
||||||
// URLs are not allowed
|
|
||||||
if strings.Contains(input, "http") || strings.Contains(input, "https") {
|
|
||||||
return "", fmt.Errorf("authorization code cannot contain URLs")
|
|
||||||
}
|
|
||||||
|
|
||||||
return input, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClaudeHTTPClient returns a configured HTTP client for Claude OAuth operations
|
|
||||||
func GetClaudeHTTPClient() *http.Client {
|
|
||||||
return &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshClaudeToken refreshes a Claude OAuth token using the refresh token
|
|
||||||
func RefreshClaudeToken(accessToken, refreshToken string) (*oauth2.Token, error) {
|
|
||||||
config := GetClaudeOAuthConfig()
|
|
||||||
|
|
||||||
// Create token from current values
|
|
||||||
currentToken := &oauth2.Token{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
if client := GetClaudeHTTPClient(); client != nil {
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh the token
|
|
||||||
newToken, err := config.TokenSource(ctx, currentToken).Token()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to refresh Claude token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newToken, nil
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
|
||||||
"one-api/model"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StartClaudeTokenRefreshScheduler starts the scheduled token refresh for Claude Code channels
|
|
||||||
func StartClaudeTokenRefreshScheduler() {
|
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
|
||||||
gopool.Go(func() {
|
|
||||||
defer ticker.Stop()
|
|
||||||
for range ticker.C {
|
|
||||||
RefreshClaudeCodeTokens()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
common.SysLog("Claude Code token refresh scheduler started (5 minute interval)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshClaudeCodeTokens refreshes tokens for all active Claude Code channels
|
|
||||||
func RefreshClaudeCodeTokens() {
|
|
||||||
var channels []model.Channel
|
|
||||||
|
|
||||||
// Get all active Claude Code channels
|
|
||||||
err := model.DB.Where("type = ? AND status = ?", constant.ChannelTypeClaudeCode, common.ChannelStatusEnabled).Find(&channels).Error
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Failed to get Claude Code channels: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
refreshCount := 0
|
|
||||||
for _, channel := range channels {
|
|
||||||
if refreshTokenForChannel(&channel) {
|
|
||||||
refreshCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if refreshCount > 0 {
|
|
||||||
common.SysLog(fmt.Sprintf("Successfully refreshed %d Claude Code channel tokens", refreshCount))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshTokenForChannel attempts to refresh token for a single channel
|
|
||||||
func refreshTokenForChannel(channel *model.Channel) bool {
|
|
||||||
// Parse key in format: accesstoken|refreshtoken
|
|
||||||
if channel.Key == "" || !strings.Contains(channel.Key, "|") {
|
|
||||||
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(channel.Key, "|")
|
|
||||||
if len(parts) < 2 {
|
|
||||||
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken := parts[0]
|
|
||||||
refreshToken := parts[1]
|
|
||||||
|
|
||||||
if refreshToken == "" {
|
|
||||||
common.SysError(fmt.Sprintf("Channel %d has empty refresh token", channel.Id))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if token needs refresh (refresh 30 minutes before expiry)
|
|
||||||
// if !shouldRefreshToken(accessToken) {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Use shared refresh function
|
|
||||||
newToken, err := RefreshClaudeToken(accessToken, refreshToken)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("Failed to refresh token for channel %d: %s", channel.Id, err.Error()))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update channel with new tokens
|
|
||||||
newKey := fmt.Sprintf("%s|%s", newToken.AccessToken, newToken.RefreshToken)
|
|
||||||
|
|
||||||
err = model.DB.Model(channel).Update("key", newKey).Error
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("Failed to update channel %d with new token: %s", channel.Id, err.Error()))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
common.SysLog(fmt.Sprintf("Successfully refreshed token for Claude Code channel %d (%s)", channel.Id, channel.Name))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -448,3 +448,353 @@ func toJSONString(v interface{}) string {
|
|||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||||
|
openaiRequest := &dto.GeneralOpenAIRequest{
|
||||||
|
Model: info.UpstreamModelName,
|
||||||
|
Stream: info.IsStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换 messages
|
||||||
|
var messages []dto.Message
|
||||||
|
for _, content := range geminiRequest.Contents {
|
||||||
|
message := dto.Message{
|
||||||
|
Role: convertGeminiRoleToOpenAI(content.Role),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 parts
|
||||||
|
var mediaContents []dto.MediaContent
|
||||||
|
var toolCalls []dto.ToolCallRequest
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
mediaContent := dto.MediaContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: part.Text,
|
||||||
|
}
|
||||||
|
mediaContents = append(mediaContents, mediaContent)
|
||||||
|
} else if part.InlineData != nil {
|
||||||
|
mediaContent := dto.MediaContent{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageUrl: &dto.MessageImageUrl{
|
||||||
|
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
|
||||||
|
Detail: "auto",
|
||||||
|
MimeType: part.InlineData.MimeType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mediaContents = append(mediaContents, mediaContent)
|
||||||
|
} else if part.FileData != nil {
|
||||||
|
mediaContent := dto.MediaContent{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageUrl: &dto.MessageImageUrl{
|
||||||
|
Url: part.FileData.FileUri,
|
||||||
|
Detail: "auto",
|
||||||
|
MimeType: part.FileData.MimeType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mediaContents = append(mediaContents, mediaContent)
|
||||||
|
} else if part.FunctionCall != nil {
|
||||||
|
// 处理 Gemini 的工具调用
|
||||||
|
toolCall := dto.ToolCallRequest{
|
||||||
|
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
|
||||||
|
Type: "function",
|
||||||
|
Function: dto.FunctionRequest{
|
||||||
|
Name: part.FunctionCall.FunctionName,
|
||||||
|
Arguments: toJSONString(part.FunctionCall.Arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
} else if part.FunctionResponse != nil {
|
||||||
|
// 处理 Gemini 的工具响应,创建单独的 tool 消息
|
||||||
|
toolMessage := dto.Message{
|
||||||
|
Role: "tool",
|
||||||
|
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
|
||||||
|
}
|
||||||
|
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
|
||||||
|
messages = append(messages, toolMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置消息内容
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
// 如果有工具调用,设置工具调用
|
||||||
|
message.SetToolCalls(toolCalls)
|
||||||
|
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
|
||||||
|
// 如果只有一个文本内容,直接设置字符串
|
||||||
|
message.Content = mediaContents[0].Text
|
||||||
|
} else if len(mediaContents) > 0 {
|
||||||
|
// 如果有多个内容或包含媒体,设置为数组
|
||||||
|
message.SetMediaContent(mediaContents)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只有当消息有内容或工具调用时才添加
|
||||||
|
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
|
||||||
|
messages = append(messages, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiRequest.Messages = messages
|
||||||
|
|
||||||
|
if geminiRequest.GenerationConfig.Temperature != nil {
|
||||||
|
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
|
||||||
|
}
|
||||||
|
if geminiRequest.GenerationConfig.TopP > 0 {
|
||||||
|
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
|
||||||
|
}
|
||||||
|
if geminiRequest.GenerationConfig.TopK > 0 {
|
||||||
|
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
|
||||||
|
}
|
||||||
|
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
|
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
|
||||||
|
}
|
||||||
|
// gemini stop sequences 最多 5 个,openai stop 最多 4 个
|
||||||
|
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
|
||||||
|
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
|
||||||
|
}
|
||||||
|
if geminiRequest.GenerationConfig.CandidateCount > 0 {
|
||||||
|
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换工具调用
|
||||||
|
if len(geminiRequest.Tools) > 0 {
|
||||||
|
var tools []dto.ToolCallRequest
|
||||||
|
for _, tool := range geminiRequest.Tools {
|
||||||
|
if tool.FunctionDeclarations != nil {
|
||||||
|
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||||
|
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||||
|
if ok {
|
||||||
|
for _, function := range functionDeclarations {
|
||||||
|
openAITool := dto.ToolCallRequest{
|
||||||
|
Type: "function",
|
||||||
|
Function: dto.FunctionRequest{
|
||||||
|
Name: function.Name,
|
||||||
|
Description: function.Description,
|
||||||
|
Parameters: function.Parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tools = append(tools, openAITool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(tools) > 0 {
|
||||||
|
openaiRequest.Tools = tools
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// gemini system instructions
|
||||||
|
if geminiRequest.SystemInstructions != nil {
|
||||||
|
// 将系统指令作为第一条消息插入
|
||||||
|
systemMessage := dto.Message{
|
||||||
|
Role: "system",
|
||||||
|
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
|
||||||
|
}
|
||||||
|
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return openaiRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertGeminiRoleToOpenAI(geminiRole string) string {
|
||||||
|
switch geminiRole {
|
||||||
|
case "user":
|
||||||
|
return "user"
|
||||||
|
case "model":
|
||||||
|
return "assistant"
|
||||||
|
case "function":
|
||||||
|
return "function"
|
||||||
|
default:
|
||||||
|
return "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
|
||||||
|
var texts []string
|
||||||
|
for _, part := range parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
texts = append(texts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(texts, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
|
||||||
|
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||||
|
geminiResponse := &dto.GeminiChatResponse{
|
||||||
|
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||||
|
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||||
|
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: openAIResponse.PromptTokens,
|
||||||
|
CandidatesTokenCount: openAIResponse.CompletionTokens,
|
||||||
|
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, choice := range openAIResponse.Choices {
|
||||||
|
candidate := dto.GeminiChatCandidate{
|
||||||
|
Index: int64(choice.Index),
|
||||||
|
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置结束原因
|
||||||
|
var finishReason string
|
||||||
|
switch choice.FinishReason {
|
||||||
|
case "stop":
|
||||||
|
finishReason = "STOP"
|
||||||
|
case "length":
|
||||||
|
finishReason = "MAX_TOKENS"
|
||||||
|
case "content_filter":
|
||||||
|
finishReason = "SAFETY"
|
||||||
|
case "tool_calls":
|
||||||
|
finishReason = "STOP"
|
||||||
|
default:
|
||||||
|
finishReason = "STOP"
|
||||||
|
}
|
||||||
|
candidate.FinishReason = &finishReason
|
||||||
|
|
||||||
|
// 转换消息内容
|
||||||
|
content := dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: make([]dto.GeminiPart, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理工具调用
|
||||||
|
toolCalls := choice.Message.ParseToolCalls()
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
// 解析参数
|
||||||
|
var args map[string]interface{}
|
||||||
|
if toolCall.Function.Arguments != "" {
|
||||||
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||||
|
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
part := dto.GeminiPart{
|
||||||
|
FunctionCall: &dto.FunctionCall{
|
||||||
|
FunctionName: toolCall.Function.Name,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, part)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 处理文本内容
|
||||||
|
textContent := choice.Message.StringContent()
|
||||||
|
if textContent != "" {
|
||||||
|
part := dto.GeminiPart{
|
||||||
|
Text: textContent,
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate.Content = content
|
||||||
|
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return geminiResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
|
||||||
|
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
|
||||||
|
// 检查是否有实际内容或结束标志
|
||||||
|
hasContent := false
|
||||||
|
hasFinishReason := false
|
||||||
|
for _, choice := range openAIResponse.Choices {
|
||||||
|
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
|
||||||
|
hasContent = true
|
||||||
|
}
|
||||||
|
if choice.FinishReason != nil {
|
||||||
|
hasFinishReason = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
|
||||||
|
if !hasContent && !hasFinishReason {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiResponse := &dto.GeminiChatResponse{
|
||||||
|
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
|
||||||
|
PromptFeedback: dto.GeminiChatPromptFeedback{
|
||||||
|
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: info.PromptTokens,
|
||||||
|
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
|
||||||
|
TotalTokenCount: info.PromptTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, choice := range openAIResponse.Choices {
|
||||||
|
candidate := dto.GeminiChatCandidate{
|
||||||
|
Index: int64(choice.Index),
|
||||||
|
SafetyRatings: []dto.GeminiChatSafetyRating{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置结束原因
|
||||||
|
if choice.FinishReason != nil {
|
||||||
|
var finishReason string
|
||||||
|
switch *choice.FinishReason {
|
||||||
|
case "stop":
|
||||||
|
finishReason = "STOP"
|
||||||
|
case "length":
|
||||||
|
finishReason = "MAX_TOKENS"
|
||||||
|
case "content_filter":
|
||||||
|
finishReason = "SAFETY"
|
||||||
|
case "tool_calls":
|
||||||
|
finishReason = "STOP"
|
||||||
|
default:
|
||||||
|
finishReason = "STOP"
|
||||||
|
}
|
||||||
|
candidate.FinishReason = &finishReason
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换消息内容
|
||||||
|
content := dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: make([]dto.GeminiPart, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理工具调用
|
||||||
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
for _, toolCall := range choice.Delta.ToolCalls {
|
||||||
|
// 解析参数
|
||||||
|
var args map[string]interface{}
|
||||||
|
if toolCall.Function.Arguments != "" {
|
||||||
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||||
|
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
args = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
part := dto.GeminiPart{
|
||||||
|
FunctionCall: &dto.FunctionCall{
|
||||||
|
FunctionName: toolCall.Function.Name,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, part)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 处理文本内容
|
||||||
|
textContent := choice.Delta.GetContentString()
|
||||||
|
if textContent != "" {
|
||||||
|
part := dto.GeminiPart{
|
||||||
|
Text: textContent,
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate.Content = content
|
||||||
|
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return geminiResponse
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,9 +13,6 @@ var AutomaticDisableKeywords = []string{
|
|||||||
"The security token included in the request is invalid",
|
"The security token included in the request is invalid",
|
||||||
"Operation not allowed",
|
"Operation not allowed",
|
||||||
"Your account is not authorized",
|
"Your account is not authorized",
|
||||||
// Claude Code
|
|
||||||
"Invalid bearer token",
|
|
||||||
"OAuth authentication is currently not allowed for this endpoint",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticDisableKeywordsToString() string {
|
func AutomaticDisableKeywordsToString() string {
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ type NewAPIError struct {
|
|||||||
Err error
|
Err error
|
||||||
RelayError any
|
RelayError any
|
||||||
skipRetry bool
|
skipRetry bool
|
||||||
|
recordErrorLog *bool
|
||||||
errorType ErrorType
|
errorType ErrorType
|
||||||
errorCode ErrorCode
|
errorCode ErrorCode
|
||||||
StatusCode int
|
StatusCode int
|
||||||
@@ -278,3 +279,20 @@ func ErrOptionWithSkipRetry() NewAPIErrorOptions {
|
|||||||
e.skipRetry = true
|
e.skipRetry = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||||
|
return func(e *NewAPIError) {
|
||||||
|
e.recordErrorLog = common.GetPointer(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||||
|
if e == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if e.recordErrorLog == nil {
|
||||||
|
// default to true if not set
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *e.recordErrorLog
|
||||||
|
}
|
||||||
|
|||||||
@@ -65,7 +65,8 @@ const JSONEditor = ({
|
|||||||
const keyCount = Object.keys(parsed).length;
|
const keyCount = Object.keys(parsed).length;
|
||||||
return keyCount > 10 ? 'manual' : 'visual';
|
return keyCount > 10 ? 'manual' : 'visual';
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return 'visual';
|
// JSON无效时默认显示手动编辑模式
|
||||||
|
return 'manual';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 'visual';
|
return 'visual';
|
||||||
@@ -201,6 +202,18 @@ const JSONEditor = ({
|
|||||||
|
|
||||||
// 渲染键值对编辑器
|
// 渲染键值对编辑器
|
||||||
const renderKeyValueEditor = () => {
|
const renderKeyValueEditor = () => {
|
||||||
|
if (typeof jsonData !== 'object' || jsonData === null) {
|
||||||
|
return (
|
||||||
|
<div className="text-center py-6 px-4">
|
||||||
|
<div className="text-gray-400 mb-2">
|
||||||
|
<IconCode size={32} />
|
||||||
|
</div>
|
||||||
|
<Text type="tertiary" className="text-gray-500 text-sm">
|
||||||
|
{t('无效的JSON数据,请检查格式')}
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
const entries = Object.entries(jsonData);
|
const entries = Object.entries(jsonData);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
For commercial licensing, please contact support@quantumnous.com
|
For commercial licensing, please contact support@quantumnous.com
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import React, { useEffect, useState, useRef, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
API,
|
API,
|
||||||
showError,
|
showError,
|
||||||
@@ -24,42 +26,38 @@ import {
|
|||||||
showSuccess,
|
showSuccess,
|
||||||
verifyJSON,
|
verifyJSON,
|
||||||
} from '../../../../helpers';
|
} from '../../../../helpers';
|
||||||
|
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
|
||||||
|
import { CHANNEL_OPTIONS } from '../../../../constants';
|
||||||
import {
|
import {
|
||||||
Avatar,
|
|
||||||
Banner,
|
|
||||||
Button,
|
|
||||||
Card,
|
|
||||||
Checkbox,
|
|
||||||
Col,
|
|
||||||
Form,
|
|
||||||
Highlight,
|
|
||||||
ImagePreview,
|
|
||||||
Input,
|
|
||||||
Modal,
|
|
||||||
Row,
|
|
||||||
SideSheet,
|
SideSheet,
|
||||||
Space,
|
Space,
|
||||||
Spin,
|
Spin,
|
||||||
Tag,
|
Button,
|
||||||
Typography,
|
Typography,
|
||||||
|
Checkbox,
|
||||||
|
Banner,
|
||||||
|
Modal,
|
||||||
|
ImagePreview,
|
||||||
|
Card,
|
||||||
|
Tag,
|
||||||
|
Avatar,
|
||||||
|
Form,
|
||||||
|
Row,
|
||||||
|
Col,
|
||||||
|
Highlight,
|
||||||
} from '@douyinfe/semi-ui';
|
} from '@douyinfe/semi-ui';
|
||||||
import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers';
|
import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers';
|
||||||
import ModelSelectModal from './ModelSelectModal';
|
import ModelSelectModal from './ModelSelectModal';
|
||||||
import JSONEditor from '../../../common/JSONEditor';
|
import JSONEditor from '../../../common/JSONEditor';
|
||||||
import { CHANNEL_OPTIONS, CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT } from '../../../../constants';
|
|
||||||
import {
|
import {
|
||||||
IconBolt,
|
|
||||||
IconClose,
|
|
||||||
IconCode,
|
|
||||||
IconGlobe,
|
|
||||||
IconSave,
|
IconSave,
|
||||||
|
IconClose,
|
||||||
IconServer,
|
IconServer,
|
||||||
IconSetting,
|
IconSetting,
|
||||||
|
IconCode,
|
||||||
|
IconGlobe,
|
||||||
|
IconBolt,
|
||||||
} from '@douyinfe/semi-icons';
|
} from '@douyinfe/semi-icons';
|
||||||
import React, { useEffect, useMemo, useRef, useState } from 'react';
|
|
||||||
|
|
||||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const { Text, Title } = Typography;
|
const { Text, Title } = Typography;
|
||||||
|
|
||||||
@@ -95,8 +93,6 @@ function type2secretPrompt(type) {
|
|||||||
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
|
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
|
||||||
case 51:
|
case 51:
|
||||||
return '按照如下格式输入: Access Key ID|Secret Access Key';
|
return '按照如下格式输入: Access Key ID|Secret Access Key';
|
||||||
case 53:
|
|
||||||
return '按照如下格式输入:AccessToken|RefreshToken';
|
|
||||||
default:
|
default:
|
||||||
return '请输入渠道对应的鉴权密钥';
|
return '请输入渠道对应的鉴权密钥';
|
||||||
}
|
}
|
||||||
@@ -149,10 +145,6 @@ const EditChannelModal = (props) => {
|
|||||||
const [customModel, setCustomModel] = useState('');
|
const [customModel, setCustomModel] = useState('');
|
||||||
const [modalImageUrl, setModalImageUrl] = useState('');
|
const [modalImageUrl, setModalImageUrl] = useState('');
|
||||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||||
const [showOAuthModal, setShowOAuthModal] = useState(false);
|
|
||||||
const [authorizationCode, setAuthorizationCode] = useState('');
|
|
||||||
const [oauthParams, setOauthParams] = useState(null);
|
|
||||||
const [isExchangingCode, setIsExchangingCode] = useState(false);
|
|
||||||
const [modelModalVisible, setModelModalVisible] = useState(false);
|
const [modelModalVisible, setModelModalVisible] = useState(false);
|
||||||
const [fetchedModels, setFetchedModels] = useState([]);
|
const [fetchedModels, setFetchedModels] = useState([]);
|
||||||
const formApiRef = useRef(null);
|
const formApiRef = useRef(null);
|
||||||
@@ -162,6 +154,7 @@ const EditChannelModal = (props) => {
|
|||||||
const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false);
|
const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false);
|
||||||
const [channelSearchValue, setChannelSearchValue] = useState('');
|
const [channelSearchValue, setChannelSearchValue] = useState('');
|
||||||
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
||||||
|
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||||
// 渠道额外设置状态
|
// 渠道额外设置状态
|
||||||
const [channelSettings, setChannelSettings] = useState({
|
const [channelSettings, setChannelSettings] = useState({
|
||||||
force_format: false,
|
force_format: false,
|
||||||
@@ -361,24 +354,6 @@ const EditChannelModal = (props) => {
|
|||||||
data.system_prompt = '';
|
data.system_prompt = '';
|
||||||
}
|
}
|
||||||
|
|
||||||
// 特殊处理Claude Code渠道的密钥拆分和系统提示词
|
|
||||||
if (data.type === 53) {
|
|
||||||
// 拆分密钥
|
|
||||||
if (data.key) {
|
|
||||||
const keyParts = data.key.split('|');
|
|
||||||
if (keyParts.length === 2) {
|
|
||||||
data.access_token = keyParts[0];
|
|
||||||
data.refresh_token = keyParts[1];
|
|
||||||
} else {
|
|
||||||
// 如果没有 | 分隔符,表示只有access token
|
|
||||||
data.access_token = data.key;
|
|
||||||
data.refresh_token = '';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 强制设置固定系统提示词
|
|
||||||
data.system_prompt = CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT;
|
|
||||||
}
|
|
||||||
|
|
||||||
setInputs(data);
|
setInputs(data);
|
||||||
if (formApiRef.current) {
|
if (formApiRef.current) {
|
||||||
formApiRef.current.setValues(data);
|
formApiRef.current.setValues(data);
|
||||||
@@ -502,72 +477,6 @@ const EditChannelModal = (props) => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// 生成OAuth授权URL
|
|
||||||
const handleGenerateOAuth = async () => {
|
|
||||||
try {
|
|
||||||
setLoading(true);
|
|
||||||
const res = await API.get('/api/channel/claude/oauth/url');
|
|
||||||
if (res.data.success) {
|
|
||||||
setOauthParams(res.data.data);
|
|
||||||
setShowOAuthModal(true);
|
|
||||||
showSuccess(t('OAuth授权URL生成成功'));
|
|
||||||
} else {
|
|
||||||
showError(res.data.message || t('生成OAuth授权URL失败'));
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showError(t('生成OAuth授权URL失败:') + error.message);
|
|
||||||
} finally {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// 交换授权码
|
|
||||||
const handleExchangeCode = async () => {
|
|
||||||
if (!authorizationCode.trim()) {
|
|
||||||
showError(t('请输入授权码'));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!oauthParams) {
|
|
||||||
showError(t('OAuth参数丢失,请重新生成'));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
setIsExchangingCode(true);
|
|
||||||
const res = await API.post('/api/channel/claude/oauth/exchange', {
|
|
||||||
authorization_code: authorizationCode,
|
|
||||||
code_verifier: oauthParams.code_verifier,
|
|
||||||
state: oauthParams.state,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (res.data.success) {
|
|
||||||
const tokenData = res.data.data;
|
|
||||||
// 自动填充access token和refresh token
|
|
||||||
handleInputChange('access_token', tokenData.access_token);
|
|
||||||
handleInputChange('refresh_token', tokenData.refresh_token);
|
|
||||||
handleInputChange('key', `${tokenData.access_token}|${tokenData.refresh_token}`);
|
|
||||||
|
|
||||||
// 更新表单字段
|
|
||||||
if (formApiRef.current) {
|
|
||||||
formApiRef.current.setValue('access_token', tokenData.access_token);
|
|
||||||
formApiRef.current.setValue('refresh_token', tokenData.refresh_token);
|
|
||||||
}
|
|
||||||
|
|
||||||
setShowOAuthModal(false);
|
|
||||||
setAuthorizationCode('');
|
|
||||||
setOauthParams(null);
|
|
||||||
showSuccess(t('授权码交换成功,已自动填充tokens'));
|
|
||||||
} else {
|
|
||||||
showError(res.data.message || t('授权码交换失败'));
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showError(t('授权码交换失败:') + error.message);
|
|
||||||
} finally {
|
|
||||||
setIsExchangingCode(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const modelMap = new Map();
|
const modelMap = new Map();
|
||||||
|
|
||||||
@@ -652,6 +561,12 @@ const EditChannelModal = (props) => {
|
|||||||
pass_through_body_enabled: false,
|
pass_through_body_enabled: false,
|
||||||
system_prompt: '',
|
system_prompt: '',
|
||||||
});
|
});
|
||||||
|
// 重置密钥模式状态
|
||||||
|
setKeyMode('append');
|
||||||
|
// 清空表单中的key_mode字段
|
||||||
|
if (formApiRef.current) {
|
||||||
|
formApiRef.current.setValue('key_mode', undefined);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, [props.visible, channelId]);
|
}, [props.visible, channelId]);
|
||||||
|
|
||||||
@@ -817,6 +732,7 @@ const EditChannelModal = (props) => {
|
|||||||
res = await API.put(`/api/channel/`, {
|
res = await API.put(`/api/channel/`, {
|
||||||
...localInputs,
|
...localInputs,
|
||||||
id: parseInt(channelId),
|
id: parseInt(channelId),
|
||||||
|
key_mode: isMultiKeyChannel ? keyMode : undefined, // 只在多key模式下传递
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
res = await API.post(`/api/channel/`, {
|
res = await API.post(`/api/channel/`, {
|
||||||
@@ -879,8 +795,9 @@ const EditChannelModal = (props) => {
|
|||||||
const batchAllowed = !isEdit || isMultiKeyChannel;
|
const batchAllowed = !isEdit || isMultiKeyChannel;
|
||||||
const batchExtra = batchAllowed ? (
|
const batchExtra = batchAllowed ? (
|
||||||
<Space>
|
<Space>
|
||||||
|
{!isEdit && (
|
||||||
<Checkbox
|
<Checkbox
|
||||||
disabled={isEdit || inputs.type === 53}
|
disabled={isEdit}
|
||||||
checked={batch}
|
checked={batch}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
const checked = e.target.checked;
|
const checked = e.target.checked;
|
||||||
@@ -927,7 +844,10 @@ const EditChannelModal = (props) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
>{t('批量创建')}</Checkbox>
|
>
|
||||||
|
{t('批量创建')}
|
||||||
|
</Checkbox>
|
||||||
|
)}
|
||||||
{batch && (
|
{batch && (
|
||||||
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
|
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
|
||||||
setMultiToSingle(prev => !prev);
|
setMultiToSingle(prev => !prev);
|
||||||
@@ -1124,7 +1044,16 @@ const EditChannelModal = (props) => {
|
|||||||
autosize
|
autosize
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
onChange={(value) => handleInputChange('key', value)}
|
onChange={(value) => handleInputChange('key', value)}
|
||||||
extraText={batchExtra}
|
extraText={
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||||
|
<Text type="warning" size="small">
|
||||||
|
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
{batchExtra}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
showClear
|
showClear
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
@@ -1191,6 +1120,11 @@ const EditChannelModal = (props) => {
|
|||||||
<Text type="tertiary" size="small">
|
<Text type="tertiary" size="small">
|
||||||
{t('请输入完整的 JSON 格式密钥内容')}
|
{t('请输入完整的 JSON 格式密钥内容')}
|
||||||
</Text>
|
</Text>
|
||||||
|
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||||
|
<Text type="warning" size="small">
|
||||||
|
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
{batchExtra}
|
{batchExtra}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
@@ -1216,49 +1150,6 @@ const EditChannelModal = (props) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
) : inputs.type === 53 ? (
|
|
||||||
<>
|
|
||||||
<Form.Input
|
|
||||||
field='access_token'
|
|
||||||
label={isEdit ? t('Access Token(编辑模式下,保存的密钥不会显示)') : t('Access Token')}
|
|
||||||
placeholder={t('sk-ant-xxx')}
|
|
||||||
rules={isEdit ? [] : [{ required: true, message: t('请输入Access Token') }]}
|
|
||||||
autoComplete='new-password'
|
|
||||||
onChange={(value) => {
|
|
||||||
handleInputChange('access_token', value);
|
|
||||||
// 同时更新key字段,格式为access_token|refresh_token
|
|
||||||
const refreshToken = inputs.refresh_token || '';
|
|
||||||
handleInputChange('key', `${value}|${refreshToken}`);
|
|
||||||
}}
|
|
||||||
suffix={
|
|
||||||
<Button
|
|
||||||
size="small"
|
|
||||||
type="primary"
|
|
||||||
theme="light"
|
|
||||||
onClick={handleGenerateOAuth}
|
|
||||||
>
|
|
||||||
{t('生成OAuth授权码')}
|
|
||||||
</Button>
|
|
||||||
}
|
|
||||||
extraText={batchExtra}
|
|
||||||
showClear
|
|
||||||
/>
|
|
||||||
<Form.Input
|
|
||||||
field='refresh_token'
|
|
||||||
label={isEdit ? t('Refresh Token(编辑模式下,保存的密钥不会显示)') : t('Refresh Token')}
|
|
||||||
placeholder={t('sk-ant-xxx(可选)')}
|
|
||||||
rules={[]}
|
|
||||||
autoComplete='new-password'
|
|
||||||
onChange={(value) => {
|
|
||||||
handleInputChange('refresh_token', value);
|
|
||||||
// 同时更新key字段,格式为access_token|refresh_token
|
|
||||||
const accessToken = inputs.access_token || '';
|
|
||||||
handleInputChange('key', `${accessToken}|${value}`);
|
|
||||||
}}
|
|
||||||
extraText={batchExtra}
|
|
||||||
showClear
|
|
||||||
/>
|
|
||||||
</>
|
|
||||||
) : (
|
) : (
|
||||||
<Form.Input
|
<Form.Input
|
||||||
field='key'
|
field='key'
|
||||||
@@ -1267,13 +1158,44 @@ const EditChannelModal = (props) => {
|
|||||||
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
|
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
onChange={(value) => handleInputChange('key', value)}
|
onChange={(value) => handleInputChange('key', value)}
|
||||||
extraText={batchExtra}
|
extraText={
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
|
||||||
|
<Text type="warning" size="small">
|
||||||
|
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
{batchExtra}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
showClear
|
showClear
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{isEdit && isMultiKeyChannel && (
|
||||||
|
<Form.Select
|
||||||
|
field='key_mode'
|
||||||
|
label={t('密钥更新模式')}
|
||||||
|
placeholder={t('请选择密钥更新模式')}
|
||||||
|
optionList={[
|
||||||
|
{ label: t('追加到现有密钥'), value: 'append' },
|
||||||
|
{ label: t('覆盖现有密钥'), value: 'replace' },
|
||||||
|
]}
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
value={keyMode}
|
||||||
|
onChange={(value) => setKeyMode(value)}
|
||||||
|
extraText={
|
||||||
|
<Text type="tertiary" size="small">
|
||||||
|
{keyMode === 'replace'
|
||||||
|
? t('覆盖模式:将完全替换现有的所有密钥')
|
||||||
|
: t('追加模式:将新密钥添加到现有密钥列表末尾')
|
||||||
|
}
|
||||||
|
</Text>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{batch && multiToSingle && (
|
{batch && multiToSingle && (
|
||||||
<>
|
<>
|
||||||
<Form.Select
|
<Form.Select
|
||||||
@@ -1767,19 +1689,11 @@ const EditChannelModal = (props) => {
|
|||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
field='system_prompt'
|
field='system_prompt'
|
||||||
label={t('系统提示词')}
|
label={t('系统提示词')}
|
||||||
placeholder={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : t('输入系统提示词,用户的系统提示词将优先于此设置')}
|
placeholder={t('输入系统提示词,用户的系统提示词将优先于此设置')}
|
||||||
onChange={(value) => {
|
onChange={(value) => handleChannelSettingsChange('system_prompt', value)}
|
||||||
if (inputs.type === 53) {
|
|
||||||
// Claude Code渠道系统提示词固定,不允许修改
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
handleChannelSettingsChange('system_prompt', value);
|
|
||||||
}}
|
|
||||||
disabled={inputs.type === 53}
|
|
||||||
value={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : undefined}
|
|
||||||
autosize
|
autosize
|
||||||
showClear={inputs.type !== 53}
|
showClear
|
||||||
extraText={inputs.type === 53 ? t('Claude Code渠道系统提示词固定为官方CLI身份,不可修改') : t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')}
|
extraText={t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')}
|
||||||
/>
|
/>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
@@ -1803,68 +1717,6 @@ const EditChannelModal = (props) => {
|
|||||||
}}
|
}}
|
||||||
onCancel={() => setModelModalVisible(false)}
|
onCancel={() => setModelModalVisible(false)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* OAuth Authorization Modal */}
|
|
||||||
<Modal
|
|
||||||
title={t('生成Claude Code OAuth授权码')}
|
|
||||||
visible={showOAuthModal}
|
|
||||||
onCancel={() => {
|
|
||||||
setShowOAuthModal(false);
|
|
||||||
setAuthorizationCode('');
|
|
||||||
setOauthParams(null);
|
|
||||||
}}
|
|
||||||
onOk={handleExchangeCode}
|
|
||||||
okText={isExchangingCode ? t('交换中...') : t('确认')}
|
|
||||||
cancelText={t('取消')}
|
|
||||||
confirmLoading={isExchangingCode}
|
|
||||||
width={600}
|
|
||||||
>
|
|
||||||
<div className="space-y-4">
|
|
||||||
<div>
|
|
||||||
<Text className="text-sm font-medium mb-2 block">{t('请访问以下授权地址:')}</Text>
|
|
||||||
<div className="p-3 bg-gray-50 rounded-lg border">
|
|
||||||
<Text
|
|
||||||
link
|
|
||||||
underline
|
|
||||||
className="text-sm font-mono break-all cursor-pointer text-blue-600 hover:text-blue-800"
|
|
||||||
onClick={() => {
|
|
||||||
if (oauthParams?.auth_url) {
|
|
||||||
window.open(oauthParams.auth_url, '_blank');
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{oauthParams?.auth_url || t('正在生成授权地址...')}
|
|
||||||
</Text>
|
|
||||||
<div className="mt-2">
|
|
||||||
<Text
|
|
||||||
copyable={{ content: oauthParams?.auth_url }}
|
|
||||||
type="tertiary"
|
|
||||||
size="small"
|
|
||||||
>
|
|
||||||
{t('复制链接')}
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Text className="text-sm font-medium mb-2 block">{t('授权后,请将获得的授权码粘贴到下方:')}</Text>
|
|
||||||
<Input
|
|
||||||
value={authorizationCode}
|
|
||||||
onChange={setAuthorizationCode}
|
|
||||||
placeholder={t('请输入授权码')}
|
|
||||||
showClear
|
|
||||||
style={{ width: '100%' }}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<Banner
|
|
||||||
type="info"
|
|
||||||
description={t('获得授权码后,系统将自动换取access token和refresh token并填充到表单中。')}
|
|
||||||
className="!rounded-lg"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</Modal>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -159,14 +159,6 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'purple',
|
color: 'purple',
|
||||||
label: 'Vidu',
|
label: 'Vidu',
|
||||||
},
|
},
|
||||||
{
|
|
||||||
value: 53,
|
|
||||||
color: 'indigo',
|
|
||||||
label: 'Claude Code',
|
|
||||||
},
|
|
||||||
];
|
];
|
||||||
|
|
||||||
export const MODEL_TABLE_PAGE_SIZE = 10;
|
export const MODEL_TABLE_PAGE_SIZE = 10;
|
||||||
|
|
||||||
// Claude Code 相关常量
|
|
||||||
export const CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT = "You are Claude Code, Anthropic's official CLI for Claude.";
|
|
||||||
|
|||||||
@@ -353,7 +353,6 @@ export function getChannelIcon(channelType) {
|
|||||||
return <Ollama size={iconSize} />;
|
return <Ollama size={iconSize} />;
|
||||||
case 14: // Anthropic Claude
|
case 14: // Anthropic Claude
|
||||||
case 33: // AWS Claude
|
case 33: // AWS Claude
|
||||||
case 53: // Claude Code
|
|
||||||
return <Claude.Color size={iconSize} />;
|
return <Claude.Color size={iconSize} />;
|
||||||
case 41: // Vertex AI
|
case 41: // Vertex AI
|
||||||
return <Gemini.Color size={iconSize} />;
|
return <Gemini.Color size={iconSize} />;
|
||||||
|
|||||||
Reference in New Issue
Block a user