Merge branch 'alpha' into refactor/model-pricing

This commit is contained in:
t0ng7u
2025-08-02 22:26:40 +08:00
64 changed files with 1052 additions and 937 deletions

View File

@@ -65,8 +65,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeCoze apiType = constant.APITypeCoze
case constant.ChannelTypeJimeng: case constant.ChannelTypeJimeng:
apiType = constant.APITypeJimeng apiType = constant.APITypeJimeng
case constant.ChannelTypeClaudeCode:
apiType = constant.APITypeClaudeCode
} }
if apiType == -1 { if apiType == -1 {
return constant.APITypeOpenAI, false return constant.APITypeOpenAI, false

View File

@@ -31,6 +31,5 @@ const (
APITypeXai APITypeXai
APITypeCoze APITypeCoze
APITypeJimeng APITypeJimeng
APITypeClaudeCode
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )

View File

@@ -50,7 +50,6 @@ const (
ChannelTypeKling = 50 ChannelTypeKling = 50
ChannelTypeJimeng = 51 ChannelTypeJimeng = 51
ChannelTypeVidu = 52 ChannelTypeVidu = 52
ChannelTypeClaudeCode = 53
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -109,5 +108,4 @@ var ChannelBaseURLs = []string{
"https://api.klingai.com", //50 "https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51 "https://visual.volcengineapi.com", //51
"https://api.vidu.cn", //52 "https://api.vidu.cn", //52
"https://api.anthropic.com", //53
} }

View File

@@ -36,11 +36,30 @@ type OpenAIModel struct {
Parent string `json:"parent"` Parent string `json:"parent"`
} }
type GoogleOpenAICompatibleModels []struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description,omitempty"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
MaxTemperature int `json:"maxTemperature,omitempty"`
}
type OpenAIModelsResponse struct { type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"` Data []OpenAIModel `json:"data"`
Success bool `json:"success"` Success bool `json:"success"`
} }
type GoogleOpenAICompatibleResponse struct {
Models []GoogleOpenAICompatibleModels `json:"models"`
NextPageToken string `json:"nextPageToken"`
}
func parseStatusFilter(statusParam string) int { func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) { switch strings.ToLower(statusParam) {
case "enabled", "1": case "enabled", "1":
@@ -168,26 +187,59 @@ func FetchUpstreamModels(c *gin.Context) {
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
url := fmt.Sprintf("%s/v1/models", baseURL)
var url string
switch channel.Type { switch channel.Type {
case constant.ChannelTypeGemini: case constant.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
case constant.ChannelTypeAli: case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
default:
url = fmt.Sprintf("%s/v1/models", baseURL)
}
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
var body []byte
if channel.Type == constant.ChannelTypeGemini {
body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
} else {
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
} }
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
var result OpenAIModelsResponse var result OpenAIModelsResponse
if err = json.Unmarshal(body, &result); err != nil { var parseSuccess bool
c.JSON(http.StatusOK, gin.H{
"success": false, // 适配特殊格式
"message": fmt.Sprintf("解析响应失败: %s", err.Error()), switch channel.Type {
}) case constant.ChannelTypeGemini:
return var googleResult GoogleOpenAICompatibleResponse
if err = json.Unmarshal(body, &googleResult); err == nil {
// 转换Google格式到OpenAI格式
for _, model := range googleResult.Models {
for _, gModel := range model {
result.Data = append(result.Data, OpenAIModel{
ID: gModel.Name,
})
}
}
parseSuccess = true
}
}
// 如果解析失败尝试OpenAI格式
if !parseSuccess {
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}
} }
var ids []string var ids []string
@@ -669,6 +721,7 @@ func DeleteChannelBatch(c *gin.Context) {
type PatchChannel struct { type PatchChannel struct {
model.Channel model.Channel
MultiKeyMode *string `json:"multi_key_mode"` MultiKeyMode *string `json:"multi_key_mode"`
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
} }
func UpdateChannel(c *gin.Context) { func UpdateChannel(c *gin.Context) {
@@ -688,7 +741,7 @@ func UpdateChannel(c *gin.Context) {
return return
} }
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request. // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
originChannel, err := model.GetChannelById(channel.Id, false) originChannel, err := model.GetChannelById(channel.Id, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -704,6 +757,69 @@ func UpdateChannel(c *gin.Context) {
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" { if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode) channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
} }
// 处理多key模式下的密钥追加/覆盖逻辑
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
switch *channel.KeyMode {
case "append":
// 追加模式:将新密钥添加到现有密钥列表
if originChannel.Key != "" {
var newKeys []string
var existingKeys []string
// 解析现有密钥
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
// JSON数组格式
var arr []json.RawMessage
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
existingKeys = make([]string, len(arr))
for i, v := range arr {
existingKeys[i] = string(v)
}
}
} else {
// 换行分隔格式
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
}
// 处理 Vertex AI 的特殊情况
if channel.Type == constant.ChannelTypeVertexAi {
// 尝试解析新密钥为JSON数组
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
array, err := getVertexArrayKeys(channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "追加密钥解析失败: " + err.Error(),
})
return
}
newKeys = array
} else {
// 单个JSON密钥
newKeys = []string{channel.Key}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
} else {
// 普通渠道的处理
inputKeys := strings.Split(channel.Key, "\n")
for _, key := range inputKeys {
key = strings.TrimSpace(key)
if key != "" {
newKeys = append(newKeys, key)
}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
}
}
case "replace":
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
}
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)

View File

@@ -1,73 +0,0 @@
package controller
import (
"net/http"
"one-api/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
// ExchangeCodeRequest 授权码交换请求
type ExchangeCodeRequest struct {
AuthorizationCode string `json:"authorization_code" binding:"required"`
CodeVerifier string `json:"code_verifier" binding:"required"`
State string `json:"state" binding:"required"`
}
// GenerateClaudeOAuthURL 生成Claude OAuth授权URL
func GenerateClaudeOAuthURL(c *gin.Context) {
params, err := service.GenerateOAuthParams()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "生成OAuth授权URL失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "生成OAuth授权URL成功",
"data": params,
})
}
// ExchangeClaudeOAuthCode 交换Claude OAuth授权码
func ExchangeClaudeOAuthCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "请求参数错误: " + err.Error(),
})
return
}
// 解析授权码
cleanedCode, err := service.ParseAuthorizationCode(req.AuthorizationCode)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 交换token
tokenResult, err := service.ExchangeCode(cleanedCode, req.CodeVerifier, req.State, nil)
if err != nil {
common.SysError("Claude OAuth token exchange failed: " + err.Error())
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "授权码交换失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "授权码交换成功",
"data": tokenResult,
})
}

View File

@@ -47,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }
if constant2.ErrorLogEnabled && err != nil { if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中 // 保存错误日志到mysql中
userId := c.GetInt("id") userId := c.GetInt("id")
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
@@ -62,6 +62,14 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
other["channel_id"] = channelId other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name") other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type") other["channel_type"] = c.GetInt("channel_type")
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other) model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
} }

View File

@@ -1,4 +1,4 @@
package gemini package dto
import ( import (
"encoding/json" "encoding/json"
@@ -56,7 +56,7 @@ type FunctionCall struct {
Arguments any `json:"args"` Arguments any `json:"args"`
} }
type FunctionResponse struct { type GeminiFunctionResponse struct {
Name string `json:"name"` Name string `json:"name"`
Response map[string]interface{} `json:"response"` Response map[string]interface{} `json:"response"`
} }
@@ -81,7 +81,7 @@ type GeminiPart struct {
Thought bool `json:"thought,omitempty"` Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"` FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"` ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"` CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`

1
go.mod
View File

@@ -87,7 +87,6 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.34.2 // indirect

2
go.sum
View File

@@ -231,8 +231,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=

View File

@@ -86,9 +86,6 @@ func main() {
// 数据看板 // 数据看板
go model.UpdateQuotaData() go model.UpdateQuotaData()
// Start Claude Code token refresh scheduler
service.StartClaudeTokenRefreshScheduler()
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil { if err != nil {

View File

@@ -269,6 +269,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel.ChannelInfo.IsMultiKey { if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true) common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index) common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
} else {
// 必须设置为 false否则在重试到单个 key 的时候会导致日志显示错误
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
} }
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
common.SetContextKey(c, constant.ContextKeyChannelKey, key) common.SetContextKey(c, constant.ContextKeyChannelKey, key)

View File

@@ -284,6 +284,21 @@ func FixAbility() (int, int, error) {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试") return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
} }
defer fixLock.Unlock() defer fixLock.Unlock()
// truncate abilities table
if common.UsingSQLite {
err := DB.Exec("DELETE FROM abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
return 0, 0, err
}
} else {
err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
return 0, 0, err
}
}
var channels []*Channel var channels []*Channel
// Find all channels // Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error err := DB.Model(&Channel{}).Find(&channels).Error

View File

@@ -46,6 +46,9 @@ type Channel struct {
ParamOverride *string `json:"param_override" gorm:"type:text"` ParamOverride *string `json:"param_override" gorm:"type:text"`
// add after v0.8.5 // add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
// cache info
Keys []string `json:"-" gorm:"-"`
} }
type ChannelInfo struct { type ChannelInfo struct {
@@ -71,6 +74,9 @@ func (channel *Channel) getKeys() []string {
if channel.Key == "" { if channel.Key == "" {
return []string{} return []string{}
} }
if len(channel.Keys) > 0 {
return channel.Keys
}
trimmed := strings.TrimSpace(channel.Key) trimmed := strings.TrimSpace(channel.Key)
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
if strings.HasPrefix(trimmed, "[") { if strings.HasPrefix(trimmed, "[") {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/setting" "one-api/setting"
"sort" "sort"
"strings" "strings"
@@ -66,6 +67,20 @@ func InitChannelCache() {
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = newGroup2model2channels
//channelsIDM = newChannelId2channel
for i, channel := range newChannelId2channel {
if channel.ChannelInfo.IsMultiKey {
channel.Keys = channel.getKeys()
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
if oldChannel, ok := channelsIDM[i]; ok {
// 存在旧的渠道如果是多key且轮询保留轮询索引信息
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
}
}
}
}
}
channelsIDM = newChannelId2channel channelsIDM = newChannelId2channel
channelSyncLock.Unlock() channelSyncLock.Unlock()
common.SysLog("channels synced from database") common.SysLog("channels synced from database")
@@ -203,9 +218,6 @@ func CacheGetChannel(id int) (*Channel, error) {
if !ok { if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id) return nil, fmt.Errorf("渠道# %d已不存在", id)
} }
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return c, nil return c, nil
} }
@@ -224,9 +236,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
if !ok { if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id) return nil, fmt.Errorf("渠道# %d已不存在", id)
} }
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return &c.ChannelInfo, nil return &c.ChannelInfo, nil
} }

View File

@@ -26,6 +26,7 @@ type Adaptor interface {
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
} }
type TaskAdaptor interface { type TaskAdaptor interface {

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -22,6 +22,11 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model) c.Set("request_model", request.Model)
c.Set("converted_request", request) c.Set("converted_request", request)

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")
@@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
keyParts := strings.Split(info.ApiKey, "|") keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 0 || keyParts[0] == "" { if len(keyParts) == 0 || keyParts[0] == "" {
return errors.New("invalid API key: authorization token is required") return errors.New("invalid API key: authorization token is required")
} }
if len(keyParts) > 1 { if len(keyParts) > 1 {
if keyParts[1] != "" { if keyParts[1] != "" {
req.Set("appid", keyParts[1]) req.Set("appid", keyParts[1])
} }
} }
req.Set("Authorization", "Bearer "+keyParts[0]) req.Set("Authorization", "Bearer "+keyParts[0])
return nil return nil
} }

View File

@@ -24,6 +24,11 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil return request, nil
} }

View File

@@ -1,158 +0,0 @@
package claude_code
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
DefaultSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
// Use configured system prompt if available, otherwise use default
if info.ChannelSetting.SystemPrompt != "" {
request.System = info.ChannelSetting.SystemPrompt
} else {
request.System = DefaultSystemPrompt
}
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
// Parse accesstoken|refreshtoken format and use only the access token
accessToken := info.ApiKey
if strings.Contains(info.ApiKey, "|") {
parts := strings.Split(info.ApiKey, "|")
if len(parts) >= 1 {
accessToken = parts[0]
}
}
// Claude Code specific headers - force override
req.Set("Authorization", "Bearer "+accessToken)
// 只有在没有设置的情况下才设置 anthropic-version
if req.Get("anthropic-version") == "" {
req.Set("anthropic-version", "2023-06-01")
}
req.Set("content-type", "application/json")
// 只有在 user-agent 不包含 claude-cli 时才设置
userAgent := req.Get("user-agent")
if userAgent == "" || !strings.Contains(strings.ToLower(userAgent), "claude-cli") {
req.Set("user-agent", "claude-cli/1.0.61 (external, cli)")
}
// 只有在 anthropic-beta 不包含 claude-code 时才设置
anthropicBeta := req.Get("anthropic-beta")
if anthropicBeta == "" || !strings.Contains(strings.ToLower(anthropicBeta), "claude-code") {
req.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
}
// if Anthropic-Dangerous-Direct-Browser-Access
anthropicDangerousDirectBrowserAccess := req.Get("anthropic-dangerous-direct-browser-access")
if anthropicDangerousDirectBrowserAccess == "" {
req.Set("anthropic-dangerous-direct-browser-access", "true")
}
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return claude.RequestOpenAI2ClaudeComplete(*request), nil
} else {
claudeRequest, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
// Use configured system prompt if available, otherwise use default
if info.ChannelSetting.SystemPrompt != "" {
claudeRequest.System = info.ChannelSetting.SystemPrompt
} else {
claudeRequest.System = DefaultSystemPrompt
}
return claudeRequest, nil
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claude.ClaudeHandler(c, resp, a.RequestMode, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,14 +0,0 @@
package claude_code
var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
"claude-sonnet-4-20250514",
"claude-sonnet-4-20250514-thinking",
"claude-opus-4-20250514",
"claude-opus-4-20250514-thinking",
}
var ChannelName = "claude_code"

View File

@@ -1,4 +0,0 @@
package claude_code
// Claude Code uses the same DTO structures as Claude since it's based on the same API
// This file is kept for consistency with the channel structure pattern

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
// ConvertAudioRequest implements channel.Adaptor. // ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -24,6 +24,11 @@ type Adaptor struct {
BotType int BotType int
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -20,6 +20,26 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
if len(request.Contents) > 0 {
for i, content := range request.Contents {
if i == 0 {
if request.Contents[0].Role == "" {
request.Contents[0].Role = "user"
}
}
for _, part := range content.Parts {
if part.FileData != nil {
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
part.FileData.MimeType = "video/webm"
}
}
}
}
}
return request, nil
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{} adaptor := openai.Adaptor{}
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req) oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
@@ -51,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
// build gemini imagen request // build gemini imagen request
geminiRequest := GeminiImageRequest{ geminiRequest := dto.GeminiImageRequest{
Instances: []GeminiImageInstance{ Instances: []dto.GeminiImageInstance{
{ {
Prompt: request.Prompt, Prompt: request.Prompt,
}, },
}, },
Parameters: GeminiImageParameters{ Parameters: dto.GeminiImageParameters{
SampleCount: request.N, SampleCount: request.N,
AspectRatio: aspectRatio, AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult PersonGeneration: "allow_adult", // default allow adult
@@ -138,9 +158,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
} }
// only process the first input // only process the first input
geminiRequest := GeminiEmbeddingRequest{ geminiRequest := dto.GeminiEmbeddingRequest{
Content: GeminiChatContent{ Content: dto.GeminiChatContent{
Parts: []GeminiPart{ Parts: []dto.GeminiPart{
{ {
Text: inputs[0], Text: inputs[0],
}, },

View File

@@ -1,6 +1,7 @@
package gemini package gemini
import ( import (
"github.com/pkg/errors"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -28,7 +29,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
} }
// 解析为 Gemini 原生响应格式 // 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse) err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -71,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
responseText := strings.Builder{} responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
@@ -110,10 +111,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
} }
info.SendResponseCount++
return true return true
}) })
if info.SendResponseCount == 0 {
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
if imageCount != 0 { if imageCount != 0 {
if usage.CompletionTokens == 0 { if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258 usage.CompletionTokens = imageCount * 258

View File

@@ -81,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
return budget return budget
} }
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) { func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
@@ -93,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
if len(parts) == 2 && parts[1] != "" { if len(parts) == 2 && parts[1] != "" {
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil { if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
clampedBudget := clampThinkingBudget(modelName, budgetTokens) clampedBudget := clampThinkingBudget(modelName, budgetTokens)
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(clampedBudget), ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true, IncludeThoughts: true,
} }
@@ -113,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
} }
if isUnsupported { if isUnsupported {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true, IncludeThoughts: true,
} }
} else { } else {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true, IncludeThoughts: true,
} }
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
@@ -128,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
} }
} else if strings.HasSuffix(modelName, "-nothinking") { } else if strings.HasSuffix(modelName, "-nothinking") {
if !isNew25Pro { if !isNew25Pro {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0), ThinkingBudget: common.GetPointer(0),
} }
} }
@@ -137,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
} }
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) { func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{ geminiRequest := dto.GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
GenerationConfig: GeminiChatGenerationConfig{ GenerationConfig: dto.GeminiChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.MaxTokens,
@@ -158,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
ThinkingAdaptor(&geminiRequest, info) ThinkingAdaptor(&geminiRequest, info)
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList { for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{ safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
Category: category, Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category), Threshold: model_setting.GetGeminiSafetySetting(category),
}) })
@@ -198,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
functions = append(functions, tool.Function) functions = append(functions, tool.Function)
} }
if codeExecution { if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
CodeExecution: make(map[string]string), CodeExecution: make(map[string]string),
}) })
} }
if googleSearch { if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
GoogleSearch: make(map[string]string), GoogleSearch: make(map[string]string),
}) })
} }
if len(functions) > 0 { if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{ geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
FunctionDeclarations: functions, FunctionDeclarations: functions,
}) })
} }
@@ -238,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
continue continue
} else if message.Role == "tool" || message.Role == "function" { } else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" { if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
Role: "user", Role: "user",
}) })
} }
@@ -265,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} }
} }
functionResp := &FunctionResponse{ functionResp := &dto.GeminiFunctionResponse{
Name: name, Name: name,
Response: contentMap, Response: contentMap,
} }
*parts = append(*parts, GeminiPart{ *parts = append(*parts, dto.GeminiPart{
FunctionResponse: functionResp, FunctionResponse: functionResp,
}) })
continue continue
} }
var parts []GeminiPart var parts []dto.GeminiPart
content := GeminiChatContent{ content := dto.GeminiChatContent{
Role: message.Role, Role: message.Role,
} }
// isToolCall := false // isToolCall := false
@@ -290,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments) return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
} }
} }
toolCall := GeminiPart{ toolCall := dto.GeminiPart{
FunctionCall: &FunctionCall{ FunctionCall: &dto.FunctionCall{
FunctionName: call.Function.Name, FunctionName: call.Function.Name,
Arguments: args, Arguments: args,
}, },
@@ -308,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if part.Text == "" { if part.Text == "" {
continue continue
} }
parts = append(parts, GeminiPart{ parts = append(parts, dto.GeminiPart{
Text: part.Text, Text: part.Text,
}) })
} else if part.Type == dto.ContentTypeImageURL { } else if part.Type == dto.ContentTypeImageURL {
@@ -331,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList()) return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
} }
parts = append(parts, GeminiPart{ parts = append(parts, dto.GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &dto.GeminiInlineData{
MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义 MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义
Data: fileData.Base64Data, Data: fileData.Base64Data,
}, },
@@ -342,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil { if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error()) return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
} }
parts = append(parts, GeminiPart{ parts = append(parts, dto.GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &dto.GeminiInlineData{
MimeType: format, MimeType: format,
Data: base64String, Data: base64String,
}, },
@@ -357,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil { if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error()) return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
} }
parts = append(parts, GeminiPart{ parts = append(parts, dto.GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &dto.GeminiInlineData{
MimeType: format, MimeType: format,
Data: base64String, Data: base64String,
}, },
@@ -371,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil { if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error()) return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
} }
parts = append(parts, GeminiPart{ parts = append(parts, dto.GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &dto.GeminiInlineData{
MimeType: "audio/" + part.GetInputAudio().Format, MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String, Data: base64String,
}, },
@@ -392,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} }
if len(system_content) > 0 { if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{ geminiRequest.SystemInstructions = &dto.GeminiChatContent{
Parts: []GeminiPart{ Parts: []dto.GeminiPart{
{ {
Text: strings.Join(system_content, "\n"), Text: strings.Join(system_content, "\n"),
}, },
@@ -636,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data return data
} }
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte var argsBytes []byte
var err error var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok { if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -658,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
} }
} }
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse { func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: helper.GetResponseID(c), Id: helper.GetResponseID(c),
Object: "chat.completion", Object: "chat.completion",
@@ -725,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
return &fullTextResponse return &fullTextResponse
} }
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) { func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)) choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates { for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" { if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true isStop = true
@@ -759,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
if strings.HasPrefix(part.InlineData.MimeType, "image") { if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText) texts = append(texts, imgText)
hasImage = true
} }
} else if part.FunctionCall != nil { } else if part.FunctionCall != nil {
isTools = true isTools = true
@@ -796,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Choices = choices response.Choices = choices
return &response, isStop, hasImage return &response, isStop
} }
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error { func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
@@ -824,23 +822,31 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
// responseText := "" // responseText := ""
id := helper.GetResponseID(c) id := helper.GetResponseID(c)
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
responseText := strings.Builder{}
var usage = &dto.Usage{} var usage = &dto.Usage{}
var imageCount int var imageCount int
respCount := 0
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
} }
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse) for _, candidate := range geminiResponse.Candidates {
if hasImage { for _, part := range candidate.Content.Parts {
imageCount++ if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
} }
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id response.Id = id
response.Created = createAt response.Created = createAt
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
@@ -858,7 +864,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
} }
} }
if respCount == 0 { if info.SendResponseCount == 0 {
// send first response // send first response
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)) err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
if err != nil { if err != nil {
@@ -873,11 +879,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
if isStop { if isStop {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)) _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
} }
respCount++
return true return true
}) })
if respCount == 0 { if info.SendResponseCount == 0 {
// 空补全,报错不计费 // 空补全,报错不计费
// empty response, throw an error // empty response, throw an error
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError) return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
@@ -892,6 +897,16 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
usage.PromptTokensDetails.TextTokens = usage.PromptTokens usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response) err := handleFinalStream(c, info, response)
if err != nil { if err != nil {
@@ -913,7 +928,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if common.DebugEnabled { if common.DebugEnabled {
println(string(responseBody)) println(string(responseBody))
} }
var geminiResponse GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse) err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -959,7 +974,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
var geminiResponse GeminiEmbeddingResponse var geminiResponse dto.GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
@@ -1005,7 +1020,7 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
} }
_ = resp.Body.Close() _ = resp.Body.Close()
var geminiResponse GeminiImageResponse var geminiResponse dto.GeminiImageResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
@@ -13,11 +12,18 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/types" "one-api/types"
"github.com/gin-gonic/gin"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{} openaiAdaptor := openai.Adaptor{}
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request) openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)

View File

@@ -34,6 +34,15 @@ type Adaptor struct {
ResponseFormat string ResponseFormat string
} }
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
// 使用 service.GeminiToOpenAIRequest 转换请求格式
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
if err != nil {
return nil, err
}
return a.ConvertOpenAIRequest(c, info, openaiRequest)
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
//if !strings.Contains(request.Model, "claude") { //if !strings.Contains(request.Model, "claude") {
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model) // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} }
if info.RelayMode == relayconstant.RelayModeRealtime { if info.RelayMode == relayconstant.RelayModeRealtime {

View File

@@ -2,6 +2,8 @@ package openai
import ( import (
"encoding/json" "encoding/json"
"errors"
"net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
@@ -16,11 +18,14 @@ import (
// 辅助函数 // 辅助函数
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++ info.SendResponseCount++
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent) return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude: case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info) return handleClaudeFormat(c, data, info)
case relaycommon.RelayFormatGemini:
return handleGeminiFormat(c, data, info)
} }
return nil return nil
} }
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return nil return nil
} }
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err
}
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// 如果返回 nil表示没有实际内容跳过发送
if geminiResponse == nil {
return nil
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.LogError(c, "failed to marshal gemini response: "+err.Error())
return err
}
// send gemini format response
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error { func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString()) responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -185,6 +220,37 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
for _, resp := range claudeResponses { for _, resp := range claudeResponses {
_ = helper.ClaudeData(c, *resp) _ = helper.ClaudeData(c, *resp)
} }
case relaycommon.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空finishReason 为 STOP 的响应
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
// 暂不知是否有程序会不兼容。
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// openai 流响应开头的空数据
if geminiResponse == nil {
return
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.SysError("error marshalling gemini response: " + err.Error())
return
}
// 发送最终的 Gemini 响应
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
} }
} }

View File

@@ -223,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
responseBody = claudeRespStr responseBody = claudeRespStr
case relaycommon.RelayFormatGemini:
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
geminiRespStr, err := common.Marshal(geminiResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = geminiRespStr
} }
common.IOCopyBytesGracefully(c, resp, responseBody) common.IOCopyBytesGracefully(c, resp, responseBody)

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{} adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req) return adaptor.ConvertClaudeRequest(c, info, req)

View File

@@ -25,6 +25,11 @@ type Adaptor struct {
Timestamp int64 Timestamp int64
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -44,6 +44,11 @@ type Adaptor struct {
AccountCredentials Credentials AccountCredentials Credentials
} }
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
geminiAdaptor := gemini.Adaptor{}
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok { if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v) c.Set("request_model", v)

View File

@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
}) })
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId) var cacheKey string
if info.ChannelIsMultiKey {
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
} else {
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
}
val, err := Cache.Get(cacheKey) val, err := Cache.Get(cacheKey)
if err == nil { if err == nil {
return val.(string), nil return val.(string), nil

View File

@@ -23,6 +23,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
//panic("implement me") //panic("implement me")

View File

@@ -17,6 +17,11 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest request *dto.GeneralOpenAIRequest
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -60,17 +60,19 @@ type ResponsesUsageInfo struct {
} }
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
TokenId int ChannelIsMultiKey bool // 是否多密钥
TokenKey string ChannelMultiKeyIndex int // 多密钥索引
UserId int TokenId int
UsingGroup string // 使用的分组 TokenKey string
UserGroup string // 用户所在分组 UserId int
TokenUnlimited bool UsingGroup string // 使用的分组
StartTime time.Time UserGroup string // 用户所在分组
FirstResponseTime time.Time TokenUnlimited bool
isFirstResponse bool StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
//SendLastReasoningResponse bool //SendLastReasoningResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
@@ -260,6 +262,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
IsFirstThinkingContent: true, IsFirstThinkingContent: true,
SendLastThinkingContent: false, SendLastThinkingContent: false,
}, },
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") { if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true info.IsPlayground = true

View File

@@ -20,8 +20,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) { func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &gemini.GeminiChatRequest{} request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request) err := common.UnmarshalBodyReusable(c, request)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
// } // }
} }
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) { func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
var inputTexts []string var inputTexts []string
for _, content := range textRequest.Contents { for _, content := range textRequest.Contents {
for _, part := range content.Parts { for _, part := range content.Parts {
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
return sensitiveWords, err return sensitiveWords, err
} }
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量 // 计算输入 token 数量
var inputTexts []string var inputTexts []string
for _, content := range req.Contents { for _, content := range req.Contents {
@@ -78,9 +78,13 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
return inputTokens return inputTokens
} }
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget == 0 configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
if configBudget != nil && *configBudget == 0 {
// 如果思考预算为 0则认为是非思考请求
return true
}
} }
return false return false
} }
@@ -202,7 +206,12 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
} }
requestBody = bytes.NewReader(body) requestBody = bytes.NewReader(body)
} else { } else {
jsonData, err := common.Marshal(req) // 使用 ConvertGeminiRequest 转换请求格式
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
} }

View File

@@ -305,10 +305,10 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
} }
if userQuota <= 0 { if userQuota <= 0 {
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry()) return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
relayInfo.UserQuota = userQuota relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota { if userQuota > 100*preConsumedQuota {
@@ -332,7 +332,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota) err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry()) return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {

View File

@@ -9,7 +9,6 @@ import (
"one-api/relay/channel/baidu" "one-api/relay/channel/baidu"
"one-api/relay/channel/baidu_v2" "one-api/relay/channel/baidu_v2"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
"one-api/relay/channel/claude_code"
"one-api/relay/channel/cloudflare" "one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere" "one-api/relay/channel/cohere"
"one-api/relay/channel/coze" "one-api/relay/channel/coze"
@@ -99,8 +98,6 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &coze.Adaptor{} return &coze.Adaptor{}
case constant.APITypeJimeng: case constant.APITypeJimeng:
return &jimeng.Adaptor{} return &jimeng.Adaptor{}
case constant.APITypeClaudeCode:
return &claude_code.Adaptor{}
} }
return nil return nil
} }

View File

@@ -120,9 +120,6 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag) channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
channelRoute.GET("/tag/models", controller.GetTagModels) channelRoute.GET("/tag/models", controller.GetTagModels)
channelRoute.POST("/copy/:id", controller.CopyChannel) channelRoute.POST("/copy/:id", controller.CopyChannel)
// Claude OAuth路由
channelRoute.GET("/claude/oauth/url", controller.GenerateClaudeOAuthURL)
channelRoute.POST("/claude/oauth/exchange", controller.ExchangeClaudeOAuthCode)
} }
tokenRoute := apiRouter.Group("/token") tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth()) tokenRoute.Use(middleware.UserAuth())

View File

@@ -1,171 +0,0 @@
package service
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
)
const (
// Default OAuth configuration values
DefaultAuthorizeURL = "https://claude.ai/oauth/authorize"
DefaultTokenURL = "https://console.anthropic.com/v1/oauth/token"
DefaultClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
DefaultRedirectURI = "https://console.anthropic.com/oauth/code/callback"
DefaultScopes = "user:inference"
)
// getOAuthValues returns OAuth configuration values from environment variables or defaults
func getOAuthValues() (authorizeURL, tokenURL, clientID, redirectURI, scopes string) {
authorizeURL = os.Getenv("CLAUDE_AUTHORIZE_URL")
if authorizeURL == "" {
authorizeURL = DefaultAuthorizeURL
}
tokenURL = os.Getenv("CLAUDE_TOKEN_URL")
if tokenURL == "" {
tokenURL = DefaultTokenURL
}
clientID = os.Getenv("CLAUDE_CLIENT_ID")
if clientID == "" {
clientID = DefaultClientID
}
redirectURI = os.Getenv("CLAUDE_REDIRECT_URI")
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
scopes = os.Getenv("CLAUDE_SCOPES")
if scopes == "" {
scopes = DefaultScopes
}
return
}
type OAuth2Credentials struct {
AuthURL string `json:"auth_url"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
CodeChallenge string `json:"code_challenge"`
}
// GetClaudeOAuthConfig returns the Claude OAuth2 configuration
func GetClaudeOAuthConfig() *oauth2.Config {
authorizeURL, tokenURL, clientID, redirectURI, scopes := getOAuthValues()
return &oauth2.Config{
ClientID: clientID,
RedirectURL: redirectURI,
Scopes: strings.Split(scopes, " "),
Endpoint: oauth2.Endpoint{
AuthURL: authorizeURL,
TokenURL: tokenURL,
},
}
}
// getOAuthConfig is kept for backward compatibility
func getOAuthConfig() *oauth2.Config {
return GetClaudeOAuthConfig()
}
// GenerateOAuthParams generates OAuth authorization URL and related parameters
func GenerateOAuthParams() (*OAuth2Credentials, error) {
config := getOAuthConfig()
// Generate PKCE parameters
codeVerifier := oauth2.GenerateVerifier()
state := oauth2.GenerateVerifier() // Reuse generator as state
// Generate authorization URL
authURL := config.AuthCodeURL(state,
oauth2.S256ChallengeOption(codeVerifier),
oauth2.SetAuthURLParam("code", "true"), // Claude-specific parameter
)
return &OAuth2Credentials{
AuthURL: authURL,
CodeVerifier: codeVerifier,
State: state,
CodeChallenge: oauth2.S256ChallengeFromVerifier(codeVerifier),
}, nil
}
// ExchangeCode
func ExchangeCode(authorizationCode, codeVerifier, state string, client *http.Client) (*oauth2.Token, error) {
config := getOAuthConfig()
if strings.Contains(authorizationCode, "#") {
parts := strings.Split(authorizationCode, "#")
if len(parts) > 0 {
authorizationCode = parts[0]
}
}
ctx := context.Background()
if client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
}
token, err := config.Exchange(ctx, authorizationCode,
oauth2.VerifierOption(codeVerifier),
oauth2.SetAuthURLParam("state", state),
)
if err != nil {
return nil, fmt.Errorf("token exchange failed: %w", err)
}
return token, nil
}
func ParseAuthorizationCode(input string) (string, error) {
if input == "" {
return "", fmt.Errorf("please provide a valid authorization code")
}
// URLs are not allowed
if strings.Contains(input, "http") || strings.Contains(input, "https") {
return "", fmt.Errorf("authorization code cannot contain URLs")
}
return input, nil
}
// GetClaudeHTTPClient returns a configured HTTP client for Claude OAuth operations
func GetClaudeHTTPClient() *http.Client {
return &http.Client{
Timeout: 30 * time.Second,
}
}
// RefreshClaudeToken refreshes a Claude OAuth token using the refresh token
func RefreshClaudeToken(accessToken, refreshToken string) (*oauth2.Token, error) {
config := GetClaudeOAuthConfig()
// Create token from current values
currentToken := &oauth2.Token{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
}
ctx := context.Background()
if client := GetClaudeHTTPClient(); client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
}
// Refresh the token
newToken, err := config.TokenSource(ctx, currentToken).Token()
if err != nil {
return nil, fmt.Errorf("failed to refresh Claude token: %w", err)
}
return newToken, nil
}

View File

@@ -1,94 +0,0 @@
package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/model"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// StartClaudeTokenRefreshScheduler starts the scheduled token refresh for Claude Code channels
func StartClaudeTokenRefreshScheduler() {
ticker := time.NewTicker(5 * time.Minute)
gopool.Go(func() {
defer ticker.Stop()
for range ticker.C {
RefreshClaudeCodeTokens()
}
})
common.SysLog("Claude Code token refresh scheduler started (5 minute interval)")
}
// RefreshClaudeCodeTokens refreshes tokens for all active Claude Code channels
func RefreshClaudeCodeTokens() {
var channels []model.Channel
// Get all active Claude Code channels
err := model.DB.Where("type = ? AND status = ?", constant.ChannelTypeClaudeCode, common.ChannelStatusEnabled).Find(&channels).Error
if err != nil {
common.SysError("Failed to get Claude Code channels: " + err.Error())
return
}
refreshCount := 0
for _, channel := range channels {
if refreshTokenForChannel(&channel) {
refreshCount++
}
}
if refreshCount > 0 {
common.SysLog(fmt.Sprintf("Successfully refreshed %d Claude Code channel tokens", refreshCount))
}
}
// refreshTokenForChannel attempts to refresh token for a single channel
func refreshTokenForChannel(channel *model.Channel) bool {
// Parse key in format: accesstoken|refreshtoken
if channel.Key == "" || !strings.Contains(channel.Key, "|") {
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
return false
}
parts := strings.Split(channel.Key, "|")
if len(parts) < 2 {
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
return false
}
accessToken := parts[0]
refreshToken := parts[1]
if refreshToken == "" {
common.SysError(fmt.Sprintf("Channel %d has empty refresh token", channel.Id))
return false
}
// Check if token needs refresh (refresh 30 minutes before expiry)
// if !shouldRefreshToken(accessToken) {
// return false
// }
// Use shared refresh function
newToken, err := RefreshClaudeToken(accessToken, refreshToken)
if err != nil {
common.SysError(fmt.Sprintf("Failed to refresh token for channel %d: %s", channel.Id, err.Error()))
return false
}
// Update channel with new tokens
newKey := fmt.Sprintf("%s|%s", newToken.AccessToken, newToken.RefreshToken)
err = model.DB.Model(channel).Update("key", newKey).Error
if err != nil {
common.SysError(fmt.Sprintf("Failed to update channel %d with new token: %s", channel.Id, err.Error()))
return false
}
common.SysLog(fmt.Sprintf("Successfully refreshed token for Claude Code channel %d (%s)", channel.Id, channel.Name))
return true
}

View File

@@ -448,3 +448,353 @@ func toJSONString(v interface{}) string {
} }
return string(b) return string(b)
} }
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
openaiRequest := &dto.GeneralOpenAIRequest{
Model: info.UpstreamModelName,
Stream: info.IsStream,
}
// 转换 messages
var messages []dto.Message
for _, content := range geminiRequest.Contents {
message := dto.Message{
Role: convertGeminiRoleToOpenAI(content.Role),
}
// 处理 parts
var mediaContents []dto.MediaContent
var toolCalls []dto.ToolCallRequest
for _, part := range content.Parts {
if part.Text != "" {
mediaContent := dto.MediaContent{
Type: "text",
Text: part.Text,
}
mediaContents = append(mediaContents, mediaContent)
} else if part.InlineData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
Detail: "auto",
MimeType: part.InlineData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FileData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: part.FileData.FileUri,
Detail: "auto",
MimeType: part.FileData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FunctionCall != nil {
// 处理 Gemini 的工具调用
toolCall := dto.ToolCallRequest{
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
Type: "function",
Function: dto.FunctionRequest{
Name: part.FunctionCall.FunctionName,
Arguments: toJSONString(part.FunctionCall.Arguments),
},
}
toolCalls = append(toolCalls, toolCall)
} else if part.FunctionResponse != nil {
// 处理 Gemini 的工具响应,创建单独的 tool 消息
toolMessage := dto.Message{
Role: "tool",
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
}
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
messages = append(messages, toolMessage)
}
}
// 设置消息内容
if len(toolCalls) > 0 {
// 如果有工具调用,设置工具调用
message.SetToolCalls(toolCalls)
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
// 如果只有一个文本内容,直接设置字符串
message.Content = mediaContents[0].Text
} else if len(mediaContents) > 0 {
// 如果有多个内容或包含媒体,设置为数组
message.SetMediaContent(mediaContents)
}
// 只有当消息有内容或工具调用时才添加
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
messages = append(messages, message)
}
}
openaiRequest.Messages = messages
if geminiRequest.GenerationConfig.Temperature != nil {
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
}
if geminiRequest.GenerationConfig.TopP > 0 {
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
}
if geminiRequest.GenerationConfig.TopK > 0 {
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
}
// gemini stop sequences 最多 5 个openai stop 最多 4 个
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
}
if geminiRequest.GenerationConfig.CandidateCount > 0 {
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
}
// 转换工具调用
if len(geminiRequest.Tools) > 0 {
var tools []dto.ToolCallRequest
for _, tool := range geminiRequest.Tools {
if tool.FunctionDeclarations != nil {
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
if ok {
for _, function := range functionDeclarations {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
}
tools = append(tools, openAITool)
}
}
}
}
if len(tools) > 0 {
openaiRequest.Tools = tools
}
}
// gemini system instructions
if geminiRequest.SystemInstructions != nil {
// 将系统指令作为第一条消息插入
systemMessage := dto.Message{
Role: "system",
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
}
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
}
return openaiRequest, nil
}
func convertGeminiRoleToOpenAI(geminiRole string) string {
switch geminiRole {
case "user":
return "user"
case "model":
return "assistant"
case "function":
return "function"
default:
return "user"
}
}
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
var texts []string
for _, part := range parts {
if part.Text != "" {
texts = append(texts, part.Text)
}
}
return strings.Join(texts, "\n")
}
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: openAIResponse.PromptTokens,
CandidatesTokenCount: openAIResponse.CompletionTokens,
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
var finishReason string
switch choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
toolCalls := choice.Message.ParseToolCalls()
if len(toolCalls) > 0 {
for _, toolCall := range toolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Message.StringContent()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
// 检查是否有实际内容或结束标志
hasContent := false
hasFinishReason := false
for _, choice := range openAIResponse.Choices {
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
hasContent = true
}
if choice.FinishReason != nil {
hasFinishReason = true
}
}
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
if !hasContent && !hasFinishReason {
return nil
}
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: info.PromptTokens,
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
TotalTokenCount: info.PromptTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
if choice.FinishReason != nil {
var finishReason string
switch *choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
}
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
if choice.Delta.ToolCalls != nil {
for _, toolCall := range choice.Delta.ToolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Delta.GetContentString()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}

View File

@@ -13,9 +13,6 @@ var AutomaticDisableKeywords = []string{
"The security token included in the request is invalid", "The security token included in the request is invalid",
"Operation not allowed", "Operation not allowed",
"Your account is not authorized", "Your account is not authorized",
// Claude Code
"Invalid bearer token",
"OAuth authentication is currently not allowed for this endpoint",
} }
func AutomaticDisableKeywordsToString() string { func AutomaticDisableKeywordsToString() string {

View File

@@ -76,12 +76,13 @@ const (
) )
type NewAPIError struct { type NewAPIError struct {
Err error Err error
RelayError any RelayError any
skipRetry bool skipRetry bool
errorType ErrorType recordErrorLog *bool
errorCode ErrorCode errorType ErrorType
StatusCode int errorCode ErrorCode
StatusCode int
} }
func (e *NewAPIError) GetErrorCode() ErrorCode { func (e *NewAPIError) GetErrorCode() ErrorCode {
@@ -278,3 +279,20 @@ func ErrOptionWithSkipRetry() NewAPIErrorOptions {
e.skipRetry = true e.skipRetry = true
} }
} }
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
return func(e *NewAPIError) {
e.recordErrorLog = common.GetPointer(false)
}
}
func IsRecordErrorLog(e *NewAPIError) bool {
if e == nil {
return false
}
if e.recordErrorLog == nil {
// default to true if not set
return true
}
return *e.recordErrorLog
}

View File

@@ -65,7 +65,8 @@ const JSONEditor = ({
const keyCount = Object.keys(parsed).length; const keyCount = Object.keys(parsed).length;
return keyCount > 10 ? 'manual' : 'visual'; return keyCount > 10 ? 'manual' : 'visual';
} catch (error) { } catch (error) {
return 'visual'; // JSON无效时默认显示手动编辑模式
return 'manual';
} }
} }
return 'visual'; return 'visual';
@@ -201,6 +202,18 @@ const JSONEditor = ({
// 渲染键值对编辑器 // 渲染键值对编辑器
const renderKeyValueEditor = () => { const renderKeyValueEditor = () => {
if (typeof jsonData !== 'object' || jsonData === null) {
return (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconCode size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('无效的JSON数据请检查格式')}
</Text>
</div>
);
}
const entries = Object.entries(jsonData); const entries = Object.entries(jsonData);
return ( return (

View File

@@ -17,6 +17,8 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com For commercial licensing, please contact support@quantumnous.com
*/ */
import React, { useEffect, useState, useRef, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { import {
API, API,
showError, showError,
@@ -24,42 +26,38 @@ import {
showSuccess, showSuccess,
verifyJSON, verifyJSON,
} from '../../../../helpers'; } from '../../../../helpers';
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
import { CHANNEL_OPTIONS } from '../../../../constants';
import { import {
Avatar,
Banner,
Button,
Card,
Checkbox,
Col,
Form,
Highlight,
ImagePreview,
Input,
Modal,
Row,
SideSheet, SideSheet,
Space, Space,
Spin, Spin,
Tag, Button,
Typography, Typography,
Checkbox,
Banner,
Modal,
ImagePreview,
Card,
Tag,
Avatar,
Form,
Row,
Col,
Highlight,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers'; import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers';
import ModelSelectModal from './ModelSelectModal'; import ModelSelectModal from './ModelSelectModal';
import JSONEditor from '../../../common/JSONEditor'; import JSONEditor from '../../../common/JSONEditor';
import { CHANNEL_OPTIONS, CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT } from '../../../../constants';
import { import {
IconBolt,
IconClose,
IconCode,
IconGlobe,
IconSave, IconSave,
IconClose,
IconServer, IconServer,
IconSetting, IconSetting,
IconCode,
IconGlobe,
IconBolt,
} from '@douyinfe/semi-icons'; } from '@douyinfe/semi-icons';
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
import { useTranslation } from 'react-i18next';
const { Text, Title } = Typography; const { Text, Title } = Typography;
@@ -95,8 +93,6 @@ function type2secretPrompt(type) {
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey'; return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey';
case 51: case 51:
return '按照如下格式输入: Access Key ID|Secret Access Key'; return '按照如下格式输入: Access Key ID|Secret Access Key';
case 53:
return '按照如下格式输入AccessToken|RefreshToken';
default: default:
return '请输入渠道对应的鉴权密钥'; return '请输入渠道对应的鉴权密钥';
} }
@@ -149,10 +145,6 @@ const EditChannelModal = (props) => {
const [customModel, setCustomModel] = useState(''); const [customModel, setCustomModel] = useState('');
const [modalImageUrl, setModalImageUrl] = useState(''); const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [showOAuthModal, setShowOAuthModal] = useState(false);
const [authorizationCode, setAuthorizationCode] = useState('');
const [oauthParams, setOauthParams] = useState(null);
const [isExchangingCode, setIsExchangingCode] = useState(false);
const [modelModalVisible, setModelModalVisible] = useState(false); const [modelModalVisible, setModelModalVisible] = useState(false);
const [fetchedModels, setFetchedModels] = useState([]); const [fetchedModels, setFetchedModels] = useState([]);
const formApiRef = useRef(null); const formApiRef = useRef(null);
@@ -162,6 +154,7 @@ const EditChannelModal = (props) => {
const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false); const [isMultiKeyChannel, setIsMultiKeyChannel] = useState(false);
const [channelSearchValue, setChannelSearchValue] = useState(''); const [channelSearchValue, setChannelSearchValue] = useState('');
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式 const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
const [keyMode, setKeyMode] = useState('append'); // 密钥模式replace覆盖或 append追加
// 渠道额外设置状态 // 渠道额外设置状态
const [channelSettings, setChannelSettings] = useState({ const [channelSettings, setChannelSettings] = useState({
force_format: false, force_format: false,
@@ -361,24 +354,6 @@ const EditChannelModal = (props) => {
data.system_prompt = ''; data.system_prompt = '';
} }
// 特殊处理Claude Code渠道的密钥拆分和系统提示词
if (data.type === 53) {
// 拆分密钥
if (data.key) {
const keyParts = data.key.split('|');
if (keyParts.length === 2) {
data.access_token = keyParts[0];
data.refresh_token = keyParts[1];
} else {
// 如果没有 | 分隔符表示只有access token
data.access_token = data.key;
data.refresh_token = '';
}
}
// 强制设置固定系统提示词
data.system_prompt = CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT;
}
setInputs(data); setInputs(data);
if (formApiRef.current) { if (formApiRef.current) {
formApiRef.current.setValues(data); formApiRef.current.setValues(data);
@@ -502,72 +477,6 @@ const EditChannelModal = (props) => {
} }
}; };
// 生成OAuth授权URL
const handleGenerateOAuth = async () => {
try {
setLoading(true);
const res = await API.get('/api/channel/claude/oauth/url');
if (res.data.success) {
setOauthParams(res.data.data);
setShowOAuthModal(true);
showSuccess(t('OAuth授权URL生成成功'));
} else {
showError(res.data.message || t('生成OAuth授权URL失败'));
}
} catch (error) {
showError(t('生成OAuth授权URL失败') + error.message);
} finally {
setLoading(false);
}
};
// 交换授权码
const handleExchangeCode = async () => {
if (!authorizationCode.trim()) {
showError(t('请输入授权码'));
return;
}
if (!oauthParams) {
showError(t('OAuth参数丢失请重新生成'));
return;
}
try {
setIsExchangingCode(true);
const res = await API.post('/api/channel/claude/oauth/exchange', {
authorization_code: authorizationCode,
code_verifier: oauthParams.code_verifier,
state: oauthParams.state,
});
if (res.data.success) {
const tokenData = res.data.data;
// 自动填充access token和refresh token
handleInputChange('access_token', tokenData.access_token);
handleInputChange('refresh_token', tokenData.refresh_token);
handleInputChange('key', `${tokenData.access_token}|${tokenData.refresh_token}`);
// 更新表单字段
if (formApiRef.current) {
formApiRef.current.setValue('access_token', tokenData.access_token);
formApiRef.current.setValue('refresh_token', tokenData.refresh_token);
}
setShowOAuthModal(false);
setAuthorizationCode('');
setOauthParams(null);
showSuccess(t('授权码交换成功已自动填充tokens'));
} else {
showError(res.data.message || t('授权码交换失败'));
}
} catch (error) {
showError(t('授权码交换失败:') + error.message);
} finally {
setIsExchangingCode(false);
}
};
useEffect(() => { useEffect(() => {
const modelMap = new Map(); const modelMap = new Map();
@@ -652,6 +561,12 @@ const EditChannelModal = (props) => {
pass_through_body_enabled: false, pass_through_body_enabled: false,
system_prompt: '', system_prompt: '',
}); });
// 重置密钥模式状态
setKeyMode('append');
// 清空表单中的key_mode字段
if (formApiRef.current) {
formApiRef.current.setValue('key_mode', undefined);
}
} }
}, [props.visible, channelId]); }, [props.visible, channelId]);
@@ -817,6 +732,7 @@ const EditChannelModal = (props) => {
res = await API.put(`/api/channel/`, { res = await API.put(`/api/channel/`, {
...localInputs, ...localInputs,
id: parseInt(channelId), id: parseInt(channelId),
key_mode: isMultiKeyChannel ? keyMode : undefined, // 只在多key模式下传递
}); });
} else { } else {
res = await API.post(`/api/channel/`, { res = await API.post(`/api/channel/`, {
@@ -879,55 +795,59 @@ const EditChannelModal = (props) => {
const batchAllowed = !isEdit || isMultiKeyChannel; const batchAllowed = !isEdit || isMultiKeyChannel;
const batchExtra = batchAllowed ? ( const batchExtra = batchAllowed ? (
<Space> <Space>
<Checkbox {!isEdit && (
disabled={isEdit || inputs.type === 53} <Checkbox
checked={batch} disabled={isEdit}
onChange={(e) => { checked={batch}
const checked = e.target.checked; onChange={(e) => {
const checked = e.target.checked;
if (!checked && vertexFileList.length > 1) { if (!checked && vertexFileList.length > 1) {
Modal.confirm({ Modal.confirm({
title: t('切换为单密钥模式'), title: t('切换为单密钥模式'),
content: t('将仅保留第一个密钥文件,其余文件将被移除,是否继续?'), content: t('将仅保留第一个密钥文件,其余文件将被移除,是否继续?'),
onOk: () => { onOk: () => {
const firstFile = vertexFileList[0]; const firstFile = vertexFileList[0];
const firstKey = vertexKeys[0] ? [vertexKeys[0]] : []; const firstKey = vertexKeys[0] ? [vertexKeys[0]] : [];
setVertexFileList([firstFile]); setVertexFileList([firstFile]);
setVertexKeys(firstKey); setVertexKeys(firstKey);
formApiRef.current?.setValue('vertex_files', [firstFile]); formApiRef.current?.setValue('vertex_files', [firstFile]);
setInputs((prev) => ({ ...prev, vertex_files: [firstFile] })); setInputs((prev) => ({ ...prev, vertex_files: [firstFile] }));
setBatch(false); setBatch(false);
setMultiToSingle(false); setMultiToSingle(false);
setMultiKeyMode('random'); setMultiKeyMode('random');
}, },
onCancel: () => { onCancel: () => {
setBatch(true); setBatch(true);
}, },
centered: true, centered: true,
}); });
return; return;
}
setBatch(checked);
if (!checked) {
setMultiToSingle(false);
setMultiKeyMode('random');
} else {
// 批量模式下禁用手动输入,并清空手动输入的内容
setUseManualInput(false);
if (inputs.type === 41) {
// 清空手动输入的密钥内容
if (formApiRef.current) {
formApiRef.current.setValue('key', '');
}
handleInputChange('key', '');
} }
}
}} setBatch(checked);
>{t('批量创建')}</Checkbox> if (!checked) {
setMultiToSingle(false);
setMultiKeyMode('random');
} else {
// 批量模式下禁用手动输入,并清空手动输入的内容
setUseManualInput(false);
if (inputs.type === 41) {
// 清空手动输入的密钥内容
if (formApiRef.current) {
formApiRef.current.setValue('key', '');
}
handleInputChange('key', '');
}
}
}}
>
{t('批量创建')}
</Checkbox>
)}
{batch && ( {batch && (
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => { <Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
setMultiToSingle(prev => !prev); setMultiToSingle(prev => !prev);
@@ -1124,7 +1044,16 @@ const EditChannelModal = (props) => {
autosize autosize
autoComplete='new-password' autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)} onChange={(value) => handleInputChange('key', value)}
extraText={batchExtra} extraText={
<div className="flex items-center gap-2">
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
<Text type="warning" size="small">
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
</Text>
)}
{batchExtra}
</div>
}
showClear showClear
/> />
) )
@@ -1191,6 +1120,11 @@ const EditChannelModal = (props) => {
<Text type="tertiary" size="small"> <Text type="tertiary" size="small">
{t('请输入完整的 JSON 格式密钥内容')} {t('请输入完整的 JSON 格式密钥内容')}
</Text> </Text>
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
<Text type="warning" size="small">
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
</Text>
)}
{batchExtra} {batchExtra}
</div> </div>
} }
@@ -1216,49 +1150,6 @@ const EditChannelModal = (props) => {
/> />
)} )}
</> </>
) : inputs.type === 53 ? (
<>
<Form.Input
field='access_token'
label={isEdit ? t('Access Token编辑模式下保存的密钥不会显示') : t('Access Token')}
placeholder={t('sk-ant-xxx')}
rules={isEdit ? [] : [{ required: true, message: t('请输入Access Token') }]}
autoComplete='new-password'
onChange={(value) => {
handleInputChange('access_token', value);
// 同时更新key字段格式为access_token|refresh_token
const refreshToken = inputs.refresh_token || '';
handleInputChange('key', `${value}|${refreshToken}`);
}}
suffix={
<Button
size="small"
type="primary"
theme="light"
onClick={handleGenerateOAuth}
>
{t('生成OAuth授权码')}
</Button>
}
extraText={batchExtra}
showClear
/>
<Form.Input
field='refresh_token'
label={isEdit ? t('Refresh Token编辑模式下保存的密钥不会显示') : t('Refresh Token')}
placeholder={t('sk-ant-xxx可选')}
rules={[]}
autoComplete='new-password'
onChange={(value) => {
handleInputChange('refresh_token', value);
// 同时更新key字段格式为access_token|refresh_token
const accessToken = inputs.access_token || '';
handleInputChange('key', `${accessToken}|${value}`);
}}
extraText={batchExtra}
showClear
/>
</>
) : ( ) : (
<Form.Input <Form.Input
field='key' field='key'
@@ -1267,13 +1158,44 @@ const EditChannelModal = (props) => {
rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]} rules={isEdit ? [] : [{ required: true, message: t('请输入密钥') }]}
autoComplete='new-password' autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)} onChange={(value) => handleInputChange('key', value)}
extraText={batchExtra} extraText={
<div className="flex items-center gap-2">
{isEdit && isMultiKeyChannel && keyMode === 'append' && (
<Text type="warning" size="small">
{t('追加模式:新密钥将添加到现有密钥列表的末尾')}
</Text>
)}
{batchExtra}
</div>
}
showClear showClear
/> />
)} )}
</> </>
)} )}
{isEdit && isMultiKeyChannel && (
<Form.Select
field='key_mode'
label={t('密钥更新模式')}
placeholder={t('请选择密钥更新模式')}
optionList={[
{ label: t('追加到现有密钥'), value: 'append' },
{ label: t('覆盖现有密钥'), value: 'replace' },
]}
style={{ width: '100%' }}
value={keyMode}
onChange={(value) => setKeyMode(value)}
extraText={
<Text type="tertiary" size="small">
{keyMode === 'replace'
? t('覆盖模式:将完全替换现有的所有密钥')
: t('追加模式:将新密钥添加到现有密钥列表末尾')
}
</Text>
}
/>
)}
{batch && multiToSingle && ( {batch && multiToSingle && (
<> <>
<Form.Select <Form.Select
@@ -1767,19 +1689,11 @@ const EditChannelModal = (props) => {
<Form.TextArea <Form.TextArea
field='system_prompt' field='system_prompt'
label={t('系统提示词')} label={t('系统提示词')}
placeholder={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : t('输入系统提示词,用户的系统提示词将优先于此设置')} placeholder={t('输入系统提示词,用户的系统提示词将优先于此设置')}
onChange={(value) => { onChange={(value) => handleChannelSettingsChange('system_prompt', value)}
if (inputs.type === 53) {
// Claude Code渠道系统提示词固定不允许修改
return;
}
handleChannelSettingsChange('system_prompt', value);
}}
disabled={inputs.type === 53}
value={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : undefined}
autosize autosize
showClear={inputs.type !== 53} showClear
extraText={inputs.type === 53 ? t('Claude Code渠道系统提示词固定为官方CLI身份不可修改') : t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')} extraText={t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')}
/> />
</Card> </Card>
</div> </div>
@@ -1803,68 +1717,6 @@ const EditChannelModal = (props) => {
}} }}
onCancel={() => setModelModalVisible(false)} onCancel={() => setModelModalVisible(false)}
/> />
{/* OAuth Authorization Modal */}
<Modal
title={t('生成Claude Code OAuth授权码')}
visible={showOAuthModal}
onCancel={() => {
setShowOAuthModal(false);
setAuthorizationCode('');
setOauthParams(null);
}}
onOk={handleExchangeCode}
okText={isExchangingCode ? t('交换中...') : t('确认')}
cancelText={t('取消')}
confirmLoading={isExchangingCode}
width={600}
>
<div className="space-y-4">
<div>
<Text className="text-sm font-medium mb-2 block">{t('请访问以下授权地址:')}</Text>
<div className="p-3 bg-gray-50 rounded-lg border">
<Text
link
underline
className="text-sm font-mono break-all cursor-pointer text-blue-600 hover:text-blue-800"
onClick={() => {
if (oauthParams?.auth_url) {
window.open(oauthParams.auth_url, '_blank');
}
}}
>
{oauthParams?.auth_url || t('正在生成授权地址...')}
</Text>
<div className="mt-2">
<Text
copyable={{ content: oauthParams?.auth_url }}
type="tertiary"
size="small"
>
{t('复制链接')}
</Text>
</div>
</div>
</div>
<div>
<Text className="text-sm font-medium mb-2 block">{t('授权后,请将获得的授权码粘贴到下方:')}</Text>
<Input
value={authorizationCode}
onChange={setAuthorizationCode}
placeholder={t('请输入授权码')}
showClear
style={{ width: '100%' }}
/>
</div>
<Banner
type="info"
description={t('获得授权码后系统将自动换取access token和refresh token并填充到表单中。')}
className="!rounded-lg"
/>
</div>
</Modal>
</> </>
); );
}; };

View File

@@ -159,14 +159,6 @@ export const CHANNEL_OPTIONS = [
color: 'purple', color: 'purple',
label: 'Vidu', label: 'Vidu',
}, },
{
value: 53,
color: 'indigo',
label: 'Claude Code',
},
]; ];
export const MODEL_TABLE_PAGE_SIZE = 10; export const MODEL_TABLE_PAGE_SIZE = 10;
// Claude Code 相关常量
export const CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT = "You are Claude Code, Anthropic's official CLI for Claude.";

View File

@@ -353,7 +353,6 @@ export function getChannelIcon(channelType) {
return <Ollama size={iconSize} />; return <Ollama size={iconSize} />;
case 14: // Anthropic Claude case 14: // Anthropic Claude
case 33: // AWS Claude case 33: // AWS Claude
case 53: // Claude Code
return <Claude.Color size={iconSize} />; return <Claude.Color size={iconSize} />;
case 41: // Vertex AI case 41: // Vertex AI
return <Gemini.Color size={iconSize} />; return <Gemini.Color size={iconSize} />;