Merge branch 'alpha' into refactor/model-pricing

This commit is contained in:
t0ng7u
2025-07-31 22:28:59 +08:00
30 changed files with 1547 additions and 107 deletions

View File

@@ -65,6 +65,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = constant.APITypeCoze apiType = constant.APITypeCoze
case constant.ChannelTypeJimeng: case constant.ChannelTypeJimeng:
apiType = constant.APITypeJimeng apiType = constant.APITypeJimeng
case constant.ChannelTypeClaudeCode:
apiType = constant.APITypeClaudeCode
} }
if apiType == -1 { if apiType == -1 {
return constant.APITypeOpenAI, false return constant.APITypeOpenAI, false

View File

@@ -31,5 +31,6 @@ const (
APITypeXai APITypeXai
APITypeCoze APITypeCoze
APITypeJimeng APITypeJimeng
APITypeClaudeCode
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )

View File

@@ -50,6 +50,7 @@ const (
ChannelTypeKling = 50 ChannelTypeKling = 50
ChannelTypeJimeng = 51 ChannelTypeJimeng = 51
ChannelTypeVidu = 52 ChannelTypeVidu = 52
ChannelTypeClaudeCode = 53
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -108,4 +109,5 @@ var ChannelBaseURLs = []string{
"https://api.klingai.com", //50 "https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51 "https://visual.volcengineapi.com", //51
"https://api.vidu.cn", //52 "https://api.vidu.cn", //52
"https://api.anthropic.com", //53
} }

View File

@@ -332,8 +332,11 @@ func TestChannel(c *gin.Context) {
} }
channel, err := model.CacheGetChannel(channelId) channel, err := model.CacheGetChannel(channelId)
if err != nil { if err != nil {
common.ApiError(c, err) channel, err = model.GetChannelById(channelId, true)
return if err != nil {
common.ApiError(c, err)
return
}
} }
//defer func() { //defer func() {
// if channel.ChannelInfo.IsMultiKey { // if channel.ChannelInfo.IsMultiKey {

View File

@@ -0,0 +1,73 @@
package controller
import (
"net/http"
"one-api/common"
"one-api/service"
"github.com/gin-gonic/gin"
)
// ExchangeCodeRequest 授权码交换请求
type ExchangeCodeRequest struct {
AuthorizationCode string `json:"authorization_code" binding:"required"`
CodeVerifier string `json:"code_verifier" binding:"required"`
State string `json:"state" binding:"required"`
}
// GenerateClaudeOAuthURL 生成Claude OAuth授权URL
func GenerateClaudeOAuthURL(c *gin.Context) {
params, err := service.GenerateOAuthParams()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "生成OAuth授权URL失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "生成OAuth授权URL成功",
"data": params,
})
}
// ExchangeClaudeOAuthCode 交换Claude OAuth授权码
func ExchangeClaudeOAuthCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "请求参数错误: " + err.Error(),
})
return
}
// 解析授权码
cleanedCode, err := service.ParseAuthorizationCode(req.AuthorizationCode)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 交换token
tokenResult, err := service.ExchangeCode(cleanedCode, req.CodeVerifier, req.State, nil)
if err != nil {
common.SysError("Claude OAuth token exchange failed: " + err.Error())
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "授权码交换失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "授权码交换成功",
"data": tokenResult,
})
}

View File

@@ -2,6 +2,7 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"fmt"
"one-api/common" "one-api/common"
"one-api/types" "one-api/types"
) )
@@ -284,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
return mediaContent return mediaContent
} }
type ClaudeError struct {
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
}
type ClaudeErrorWithStatusCode struct { type ClaudeErrorWithStatusCode struct {
Error ClaudeError `json:"error"` Error types.ClaudeError `json:"error"`
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
LocalError bool LocalError bool
} }
@@ -303,7 +299,7 @@ type ClaudeResponse struct {
Completion string `json:"completion,omitempty"` Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"` StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Error *types.ClaudeError `json:"error,omitempty"` Error any `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
@@ -324,6 +320,42 @@ func (c *ClaudeResponse) GetIndex() int {
return *c.Index return *c.Index
} }
// GetClaudeError 从动态错误类型中提取ClaudeError结构
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
if c.Error == nil {
return nil
}
switch err := c.Error.(type) {
case types.ClaudeError:
return &err
case *types.ClaudeError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
claudeErr := &types.ClaudeError{}
if errType, ok := err["type"].(string); ok {
claudeErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
claudeErr.Message = errMsg
}
return claudeErr
case string:
// 处理简单字符串错误
return &types.ClaudeError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.ClaudeError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"`

View File

@@ -2,12 +2,18 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"fmt"
"one-api/types" "one-api/types"
) )
type SimpleResponse struct { type SimpleResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
Error *OpenAIError `json:"error"` Error any `json:"error"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(s.Error)
} }
type TextResponse struct { type TextResponse struct {
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
Object string `json:"object"` Object string `json:"object"`
Created any `json:"created"` Created any `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
Error *types.OpenAIError `json:"error,omitempty"` Error any `json:"error,omitempty"`
Usage `json:"usage"` Usage `json:"usage"`
} }
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
type OpenAIEmbeddingResponseItem struct { type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"` Object string `json:"object"`
Index int `json:"index"` Index int `json:"index"`
@@ -217,7 +228,7 @@ type OpenAIResponsesResponse struct {
Object string `json:"object"` Object string `json:"object"`
CreatedAt int `json:"created_at"` CreatedAt int `json:"created_at"`
Status string `json:"status"` Status string `json:"status"`
Error *types.OpenAIError `json:"error,omitempty"` Error any `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"` Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"` MaxOutputTokens int `json:"max_output_tokens"`
@@ -237,6 +248,11 @@ type OpenAIResponsesResponse struct {
Metadata json.RawMessage `json:"metadata"` Metadata json.RawMessage `json:"metadata"`
} }
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
type IncompleteDetails struct { type IncompleteDetails struct {
Reasoning string `json:"reasoning"` Reasoning string `json:"reasoning"`
} }
@@ -276,3 +292,45 @@ type ResponsesStreamResponse struct {
Delta string `json:"delta,omitempty"` Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"` Item *ResponsesOutput `json:"item,omitempty"`
} }
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func GetOpenAIError(errorField any) *types.OpenAIError {
if errorField == nil {
return nil
}
switch err := errorField.(type) {
case types.OpenAIError:
return &err
case *types.OpenAIError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
openaiErr := &types.OpenAIError{}
if errType, ok := err["type"].(string); ok {
openaiErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
openaiErr.Message = errMsg
}
if errParam, ok := err["param"].(string); ok {
openaiErr.Param = errParam
}
if errCode, ok := err["code"]; ok {
openaiErr.Code = errCode
}
return openaiErr
case string:
// 处理简单字符串错误
return &types.OpenAIError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.OpenAIError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}

1
go.mod
View File

@@ -87,6 +87,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.34.2 // indirect

2
go.sum
View File

@@ -231,6 +231,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=

View File

@@ -86,6 +86,9 @@ func main() {
// 数据看板 // 数据看板
go model.UpdateQuotaData() go model.UpdateQuotaData()
// Start Claude Code token refresh scheduler
service.StartClaudeTokenRefreshScheduler()
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil { if err != nil {

View File

@@ -239,6 +239,20 @@ func CacheUpdateChannelStatus(id int, status int) {
if channel, ok := channelsIDM[id]; ok { if channel, ok := channelsIDM[id]; ok {
channel.Status = status channel.Status = status
} }
if status != common.ChannelStatusEnabled {
// delete the channel from group2model2channels
for group, model2channels := range group2model2channels {
for model, channels := range model2channels {
for i, channelId := range channels {
if channelId == id {
// remove the channel from the slice
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
break
}
}
}
}
}
} }
func CacheUpdateChannel(channel *Channel) { func CacheUpdateChannel(channel *Channel) {

View File

@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody) return types.NewError(err, types.ErrorCodeBadResponseBody)
} }
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
} }
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody) return types.NewError(err, types.ErrorCodeBadResponseBody)
} }
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" { if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError) return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
} }
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)

View File

@@ -0,0 +1,158 @@
package claude_code
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
DefaultSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
// Use configured system prompt if available, otherwise use default
if info.ChannelSetting.SystemPrompt != "" {
request.System = info.ChannelSetting.SystemPrompt
} else {
request.System = DefaultSystemPrompt
}
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
// Parse accesstoken|refreshtoken format and use only the access token
accessToken := info.ApiKey
if strings.Contains(info.ApiKey, "|") {
parts := strings.Split(info.ApiKey, "|")
if len(parts) >= 1 {
accessToken = parts[0]
}
}
// Claude Code specific headers - force override
req.Set("Authorization", "Bearer "+accessToken)
// 只有在没有设置的情况下才设置 anthropic-version
if req.Get("anthropic-version") == "" {
req.Set("anthropic-version", "2023-06-01")
}
req.Set("content-type", "application/json")
// 只有在 user-agent 不包含 claude-cli 时才设置
userAgent := req.Get("user-agent")
if userAgent == "" || !strings.Contains(strings.ToLower(userAgent), "claude-cli") {
req.Set("user-agent", "claude-cli/1.0.61 (external, cli)")
}
// 只有在 anthropic-beta 不包含 claude-code 时才设置
anthropicBeta := req.Get("anthropic-beta")
if anthropicBeta == "" || !strings.Contains(strings.ToLower(anthropicBeta), "claude-code") {
req.Set("anthropic-beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
}
// if Anthropic-Dangerous-Direct-Browser-Access
anthropicDangerousDirectBrowserAccess := req.Get("anthropic-dangerous-direct-browser-access")
if anthropicDangerousDirectBrowserAccess == "" {
req.Set("anthropic-dangerous-direct-browser-access", "true")
}
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return claude.RequestOpenAI2ClaudeComplete(*request), nil
} else {
claudeRequest, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
// Use configured system prompt if available, otherwise use default
if info.ChannelSetting.SystemPrompt != "" {
claudeRequest.System = info.ChannelSetting.SystemPrompt
} else {
claudeRequest.System = DefaultSystemPrompt
}
return claudeRequest, nil
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claude.ClaudeHandler(c, resp, a.RequestMode, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,14 @@
package claude_code
var ModelList = []string{
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
"claude-sonnet-4-20250514",
"claude-sonnet-4-20250514-thinking",
"claude-opus-4-20250514",
"claude-opus-4-20250514-thinking",
}
var ChannelName = "claude_code"

View File

@@ -0,0 +1,4 @@
package claude_code
// Claude Code uses the same DTO structures as Claude since it's based on the same API
// This file is kept for consistency with the channel structure pattern

View File

@@ -1,6 +1,9 @@
package gemini package gemini
import "encoding/json" import (
"encoding/json"
"one-api/common"
)
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"` Contents []GeminiChatContent `json:"contents"`
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
MimeTypeSnake string `json:"mime_type"` MimeTypeSnake string `json:"mime_type"`
} }
if err := json.Unmarshal(data, &aux); err != nil { if err := common.Unmarshal(data, &aux); err != nil {
return err return err
} }
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
} }
if err := json.Unmarshal(data, &aux); err != nil { if err := common.Unmarshal(data, &aux); err != nil {
return err return err
} }

View File

@@ -184,8 +184,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
} }
forceFormat := false forceFormat := false

View File

@@ -28,8 +28,8 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if responsesResponse.Error != nil { if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
} }
// 写入新的 response body // 写入新的 response body

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel/baidu" "one-api/relay/channel/baidu"
"one-api/relay/channel/baidu_v2" "one-api/relay/channel/baidu_v2"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
"one-api/relay/channel/claude_code"
"one-api/relay/channel/cloudflare" "one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere" "one-api/relay/channel/cohere"
"one-api/relay/channel/coze" "one-api/relay/channel/coze"
@@ -98,6 +99,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &coze.Adaptor{} return &coze.Adaptor{}
case constant.APITypeJimeng: case constant.APITypeJimeng:
return &jimeng.Adaptor{} return &jimeng.Adaptor{}
case constant.APITypeClaudeCode:
return &claude_code.Adaptor{}
} }
return nil return nil
} }

View File

@@ -120,6 +120,9 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag) channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
channelRoute.GET("/tag/models", controller.GetTagModels) channelRoute.GET("/tag/models", controller.GetTagModels)
channelRoute.POST("/copy/:id", controller.CopyChannel) channelRoute.POST("/copy/:id", controller.CopyChannel)
// Claude OAuth路由
channelRoute.GET("/claude/oauth/url", controller.GenerateClaudeOAuthURL)
channelRoute.POST("/claude/oauth/exchange", controller.ExchangeClaudeOAuthCode)
} }
tokenRoute := apiRouter.Group("/token") tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth()) tokenRoute.Use(middleware.UserAuth())

171
service/claude_oauth.go Normal file
View File

@@ -0,0 +1,171 @@
package service
import (
"context"
"fmt"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
)
const (
// Default OAuth configuration values
DefaultAuthorizeURL = "https://claude.ai/oauth/authorize"
DefaultTokenURL = "https://console.anthropic.com/v1/oauth/token"
DefaultClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
DefaultRedirectURI = "https://console.anthropic.com/oauth/code/callback"
DefaultScopes = "user:inference"
)
// getOAuthValues returns OAuth configuration values from environment variables or defaults
func getOAuthValues() (authorizeURL, tokenURL, clientID, redirectURI, scopes string) {
authorizeURL = os.Getenv("CLAUDE_AUTHORIZE_URL")
if authorizeURL == "" {
authorizeURL = DefaultAuthorizeURL
}
tokenURL = os.Getenv("CLAUDE_TOKEN_URL")
if tokenURL == "" {
tokenURL = DefaultTokenURL
}
clientID = os.Getenv("CLAUDE_CLIENT_ID")
if clientID == "" {
clientID = DefaultClientID
}
redirectURI = os.Getenv("CLAUDE_REDIRECT_URI")
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
scopes = os.Getenv("CLAUDE_SCOPES")
if scopes == "" {
scopes = DefaultScopes
}
return
}
type OAuth2Credentials struct {
AuthURL string `json:"auth_url"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
CodeChallenge string `json:"code_challenge"`
}
// GetClaudeOAuthConfig returns the Claude OAuth2 configuration
func GetClaudeOAuthConfig() *oauth2.Config {
authorizeURL, tokenURL, clientID, redirectURI, scopes := getOAuthValues()
return &oauth2.Config{
ClientID: clientID,
RedirectURL: redirectURI,
Scopes: strings.Split(scopes, " "),
Endpoint: oauth2.Endpoint{
AuthURL: authorizeURL,
TokenURL: tokenURL,
},
}
}
// getOAuthConfig is kept for backward compatibility
func getOAuthConfig() *oauth2.Config {
return GetClaudeOAuthConfig()
}
// GenerateOAuthParams generates OAuth authorization URL and related parameters
func GenerateOAuthParams() (*OAuth2Credentials, error) {
config := getOAuthConfig()
// Generate PKCE parameters
codeVerifier := oauth2.GenerateVerifier()
state := oauth2.GenerateVerifier() // Reuse generator as state
// Generate authorization URL
authURL := config.AuthCodeURL(state,
oauth2.S256ChallengeOption(codeVerifier),
oauth2.SetAuthURLParam("code", "true"), // Claude-specific parameter
)
return &OAuth2Credentials{
AuthURL: authURL,
CodeVerifier: codeVerifier,
State: state,
CodeChallenge: oauth2.S256ChallengeFromVerifier(codeVerifier),
}, nil
}
// ExchangeCode
func ExchangeCode(authorizationCode, codeVerifier, state string, client *http.Client) (*oauth2.Token, error) {
config := getOAuthConfig()
if strings.Contains(authorizationCode, "#") {
parts := strings.Split(authorizationCode, "#")
if len(parts) > 0 {
authorizationCode = parts[0]
}
}
ctx := context.Background()
if client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
}
token, err := config.Exchange(ctx, authorizationCode,
oauth2.VerifierOption(codeVerifier),
oauth2.SetAuthURLParam("state", state),
)
if err != nil {
return nil, fmt.Errorf("token exchange failed: %w", err)
}
return token, nil
}
func ParseAuthorizationCode(input string) (string, error) {
if input == "" {
return "", fmt.Errorf("please provide a valid authorization code")
}
// URLs are not allowed
if strings.Contains(input, "http") || strings.Contains(input, "https") {
return "", fmt.Errorf("authorization code cannot contain URLs")
}
return input, nil
}
// GetClaudeHTTPClient returns a configured HTTP client for Claude OAuth operations
func GetClaudeHTTPClient() *http.Client {
return &http.Client{
Timeout: 30 * time.Second,
}
}
// RefreshClaudeToken refreshes a Claude OAuth token using the refresh token
func RefreshClaudeToken(accessToken, refreshToken string) (*oauth2.Token, error) {
config := GetClaudeOAuthConfig()
// Create token from current values
currentToken := &oauth2.Token{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
}
ctx := context.Background()
if client := GetClaudeHTTPClient(); client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
}
// Refresh the token
newToken, err := config.TokenSource(ctx, currentToken).Token()
if err != nil {
return nil, fmt.Errorf("failed to refresh Claude token: %w", err)
}
return newToken, nil
}

View File

@@ -0,0 +1,94 @@
package service
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/model"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
)
// StartClaudeTokenRefreshScheduler starts the scheduled token refresh for Claude Code channels
func StartClaudeTokenRefreshScheduler() {
ticker := time.NewTicker(5 * time.Minute)
gopool.Go(func() {
defer ticker.Stop()
for range ticker.C {
RefreshClaudeCodeTokens()
}
})
common.SysLog("Claude Code token refresh scheduler started (5 minute interval)")
}
// RefreshClaudeCodeTokens refreshes tokens for all active Claude Code channels
func RefreshClaudeCodeTokens() {
var channels []model.Channel
// Get all active Claude Code channels
err := model.DB.Where("type = ? AND status = ?", constant.ChannelTypeClaudeCode, common.ChannelStatusEnabled).Find(&channels).Error
if err != nil {
common.SysError("Failed to get Claude Code channels: " + err.Error())
return
}
refreshCount := 0
for _, channel := range channels {
if refreshTokenForChannel(&channel) {
refreshCount++
}
}
if refreshCount > 0 {
common.SysLog(fmt.Sprintf("Successfully refreshed %d Claude Code channel tokens", refreshCount))
}
}
// refreshTokenForChannel attempts to refresh token for a single channel
func refreshTokenForChannel(channel *model.Channel) bool {
// Parse key in format: accesstoken|refreshtoken
if channel.Key == "" || !strings.Contains(channel.Key, "|") {
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
return false
}
parts := strings.Split(channel.Key, "|")
if len(parts) < 2 {
common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
return false
}
accessToken := parts[0]
refreshToken := parts[1]
if refreshToken == "" {
common.SysError(fmt.Sprintf("Channel %d has empty refresh token", channel.Id))
return false
}
// Check if token needs refresh (refresh 30 minutes before expiry)
// if !shouldRefreshToken(accessToken) {
// return false
// }
// Use shared refresh function
newToken, err := RefreshClaudeToken(accessToken, refreshToken)
if err != nil {
common.SysError(fmt.Sprintf("Failed to refresh token for channel %d: %s", channel.Id, err.Error()))
return false
}
// Update channel with new tokens
newKey := fmt.Sprintf("%s|%s", newToken.AccessToken, newToken.RefreshToken)
err = model.DB.Model(channel).Update("key", newKey).Error
if err != nil {
common.SysError(fmt.Sprintf("Failed to update channel %d with new token: %s", channel.Id, err.Error()))
return false
}
common.SysLog(fmt.Sprintf("Successfully refreshed token for Claude Code channel %d (%s)", channel.Id, channel.Name))
return true
}

View File

@@ -188,28 +188,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
return &openAIRequest, nil return &openAIRequest, nil
} }
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
claudeError := dto.ClaudeError{
Type: "new_api_error",
Message: openAIError.Error.Message,
}
return &dto.ClaudeErrorWithStatusCode{
Error: claudeError,
StatusCode: openAIError.StatusCode,
}
}
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
openAIError := dto.OpenAIError{
Message: claudeError.Error.Message,
Type: "new_api_error",
}
return &dto.OpenAIErrorWithStatusCode{
Error: openAIError,
StatusCode: claudeError.StatusCode,
}
}
func generateStopBlock(index int) *dto.ClaudeResponse { func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{ return &dto.ClaudeResponse{
Type: "content_block_stop", Type: "content_block_stop",

View File

@@ -62,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
text = "请求上游地址失败" text = "请求上游地址失败"
} }
} }
claudeError := dto.ClaudeError{ claudeError := types.ClaudeError{
Message: text, Message: text,
Type: "new_api_error", Type: "new_api_error",
} }

View File

@@ -13,6 +13,9 @@ var AutomaticDisableKeywords = []string{
"The security token included in the request is invalid", "The security token included in the request is invalid",
"Operation not allowed", "Operation not allowed",
"Your account is not authorized", "Your account is not authorized",
// Claude Code
"Invalid bearer token",
"OAuth authentication is currently not allowed for this endpoint",
} }
func AutomaticDisableKeywordsToString() string { func AutomaticDisableKeywordsToString() string {

View File

@@ -16,8 +16,8 @@ type OpenAIError struct {
} }
type ClaudeError struct { type ClaudeError struct {
Message string `json:"message,omitempty"`
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
} }
type ErrorType string type ErrorType string

View File

@@ -0,0 +1,609 @@
import React, { useState, useEffect, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
Space,
Button,
Form,
Card,
Typography,
Banner,
Row,
Col,
InputNumber,
Switch,
Select,
Input,
} from '@douyinfe/semi-ui';
import {
IconCode,
IconEdit,
IconPlus,
IconDelete,
IconSetting,
} from '@douyinfe/semi-icons';
const { Text } = Typography;
const JSONEditor = ({
value = '',
onChange,
field,
label,
placeholder,
extraText,
showClear = true,
template,
templateLabel,
editorType = 'keyValue', // keyValue, object, region
autosize = true,
rules = [],
formApi = null,
...props
}) => {
const { t } = useTranslation();
// 初始化JSON数据
const [jsonData, setJsonData] = useState(() => {
// 初始化时解析JSON数据
if (value && value.trim()) {
try {
const parsed = JSON.parse(value);
return parsed;
} catch (error) {
return {};
}
}
return {};
});
// 根据键数量决定默认编辑模式
const [editMode, setEditMode] = useState(() => {
// 如果初始JSON数据的键数量大于10个则默认使用手动模式
if (value && value.trim()) {
try {
const parsed = JSON.parse(value);
const keyCount = Object.keys(parsed).length;
return keyCount > 10 ? 'manual' : 'visual';
} catch (error) {
return 'visual';
}
}
return 'visual';
});
const [jsonError, setJsonError] = useState('');
// 数据同步 - 当value变化时总是更新jsonData如果JSON有效
useEffect(() => {
try {
const parsed = value && value.trim() ? JSON.parse(value) : {};
setJsonData(parsed);
setJsonError('');
} catch (error) {
console.log('JSON解析失败:', error.message);
setJsonError(error.message);
// JSON格式错误时不更新jsonData
}
}, [value]);
// 处理可视化编辑的数据变化
const handleVisualChange = useCallback((newData) => {
setJsonData(newData);
setJsonError('');
const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2);
// 通过formApi设置值如果提供的话
if (formApi && field) {
formApi.setValue(field, jsonString);
}
onChange?.(jsonString);
}, [onChange, formApi, field]);
// 处理手动编辑的数据变化
const handleManualChange = useCallback((newValue) => {
onChange?.(newValue);
// 验证JSON格式
if (newValue && newValue.trim()) {
try {
const parsed = JSON.parse(newValue);
setJsonError('');
// 预先准备可视化数据,但不立即应用
// 这样切换到可视化模式时数据已经准备好了
} catch (error) {
setJsonError(error.message);
}
} else {
setJsonError('');
}
}, [onChange]);
// 切换编辑模式
const toggleEditMode = useCallback(() => {
if (editMode === 'visual') {
// 从可视化模式切换到手动模式
setEditMode('manual');
} else {
// 从手动模式切换到可视化模式需要验证JSON
try {
const parsed = value && value.trim() ? JSON.parse(value) : {};
setJsonData(parsed);
setJsonError('');
setEditMode('visual');
} catch (error) {
setJsonError(error.message);
// JSON格式错误时不切换模式
return;
}
}
}, [editMode, value]);
// 添加键值对
const addKeyValue = useCallback(() => {
const newData = { ...jsonData };
const keys = Object.keys(newData);
let newKey = 'key';
let counter = 1;
while (newData.hasOwnProperty(newKey)) {
newKey = `key${counter}`;
counter++;
}
newData[newKey] = '';
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 删除键值对
const removeKeyValue = useCallback((keyToRemove) => {
const newData = { ...jsonData };
delete newData[keyToRemove];
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 更新键名
const updateKey = useCallback((oldKey, newKey) => {
if (oldKey === newKey) return;
const newData = { ...jsonData };
const value = newData[oldKey];
delete newData[oldKey];
newData[newKey] = value;
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 更新值
const updateValue = useCallback((key, newValue) => {
const newData = { ...jsonData };
newData[key] = newValue;
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 填入模板
const fillTemplate = useCallback(() => {
if (template) {
const templateString = JSON.stringify(template, null, 2);
// 通过formApi设置值如果提供的话
if (formApi && field) {
formApi.setValue(field, templateString);
}
// 无论哪种模式都要更新值
onChange?.(templateString);
// 如果是可视化模式同时更新jsonData
if (editMode === 'visual') {
setJsonData(template);
}
// 清除错误状态
setJsonError('');
}
}, [template, onChange, editMode, formApi, field]);
// 渲染键值对编辑器
const renderKeyValueEditor = () => {
const entries = Object.entries(jsonData);
return (
<div className="space-y-1">
{entries.length === 0 && (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconCode size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('暂无数据,点击下方按钮添加键值对')}
</Text>
</div>
)}
{entries.map(([key, value], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={10}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('键名')}</Text>
<Input
placeholder={t('键名')}
value={key}
onChange={(newKey) => updateKey(key, newKey)}
size="small"
/>
</div>
</Col>
<Col span={11}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('值')}</Text>
<Input
placeholder={t('值')}
value={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
/>
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(key)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加键值对')}
</Button>
</div>
</div>
);
};
// 渲染对象编辑器用于复杂JSON
const renderObjectEditor = () => {
const entries = Object.entries(jsonData);
return (
<div className="space-y-1">
{entries.length === 0 && (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconSetting size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('暂无参数,点击下方按钮添加请求参数')}
</Text>
</div>
)}
{entries.map(([key, value], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={8}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('参数名')}</Text>
<Input
placeholder={t('参数名')}
value={key}
onChange={(newKey) => updateKey(key, newKey)}
size="small"
/>
</div>
</Col>
<Col span={13}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('参数值')} ({typeof value})</Text>
{renderValueInput(key, value)}
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(key)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加参数')}
</Button>
</div>
</div>
);
};
// 渲染参数值输入控件
const renderValueInput = (key, value) => {
const valueType = typeof value;
if (valueType === 'boolean') {
return (
<div className="flex items-center">
<Switch
checked={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
/>
<Text type="tertiary" size="small" className="ml-2">
{value ? t('true') : t('false')}
</Text>
</div>
);
}
if (valueType === 'number') {
return (
<InputNumber
value={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
style={{ width: '100%' }}
step={key === 'temperature' ? 0.1 : 1}
precision={key === 'temperature' ? 2 : 0}
placeholder={t('输入数字')}
/>
);
}
// 字符串类型或其他类型
return (
<Input
placeholder={t('参数值')}
value={String(value)}
onChange={(newValue) => {
// 尝试转换为适当的类型
let convertedValue = newValue;
if (newValue === 'true') convertedValue = true;
else if (newValue === 'false') convertedValue = false;
else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') {
convertedValue = Number(newValue);
}
updateValue(key, convertedValue);
}}
size="small"
/>
);
};
// 渲染区域编辑器(特殊格式)
const renderRegionEditor = () => {
const entries = Object.entries(jsonData);
const defaultEntry = entries.find(([key]) => key === 'default');
const modelEntries = entries.filter(([key]) => key !== 'default');
return (
<div className="space-y-1">
{/* 默认区域 */}
<Card className="!p-2 !border-blue-200 !bg-blue-50">
<div className="flex items-center mb-1">
<Text strong size="small" className="text-blue-700">{t('默认区域')}</Text>
</div>
<Input
placeholder={t('默认区域,如: us-central1')}
value={defaultEntry ? defaultEntry[1] : ''}
onChange={(value) => updateValue('default', value)}
size="small"
/>
</Card>
{/* 模型专用区域 */}
<div className="space-y-1">
<Text strong size="small">{t('模型专用区域')}</Text>
{modelEntries.map(([modelName, region], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={10}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('模型名称')}</Text>
<Input
placeholder={t('模型名称')}
value={modelName}
onChange={(newKey) => updateKey(modelName, newKey)}
size="small"
/>
</div>
</Col>
<Col span={11}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('区域')}</Text>
<Input
placeholder={t('区域')}
value={region}
onChange={(newValue) => updateValue(modelName, newValue)}
size="small"
/>
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(modelName)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加模型区域')}
</Button>
</div>
</div>
</div>
);
};
// 渲染可视化编辑器
const renderVisualEditor = () => {
switch (editorType) {
case 'region':
return renderRegionEditor();
case 'object':
return renderObjectEditor();
case 'keyValue':
default:
return renderKeyValueEditor();
}
};
const hasJsonError = jsonError && jsonError.trim() !== '';
return (
<div className="space-y-1">
{/* Label统一显示在上方 */}
{label && (
<div className="flex items-center">
<Text className="text-sm font-medium text-gray-900">{label}</Text>
</div>
)}
{/* 编辑模式切换 */}
<div className="flex items-center justify-between p-2 bg-gray-50 rounded-md">
<div className="flex items-center gap-2">
{editMode === 'visual' && (
<Text type="tertiary" size="small" className="bg-blue-100 text-blue-700 px-2 py-0.5 rounded text-xs">
{t('可视化模式')}
</Text>
)}
{editMode === 'manual' && (
<Text type="tertiary" size="small" className="bg-green-100 text-green-700 px-2 py-0.5 rounded text-xs">
{t('手动编辑模式')}
</Text>
)}
</div>
<div className="flex items-center gap-2">
{template && templateLabel && (
<Button
size="small"
type="tertiary"
onClick={fillTemplate}
className="!text-semi-color-primary hover:bg-blue-50 text-xs"
>
{templateLabel}
</Button>
)}
<Space size="tight">
<Button
size="small"
type={editMode === 'visual' ? 'primary' : 'tertiary'}
icon={<IconEdit />}
onClick={toggleEditMode}
disabled={editMode === 'manual' && hasJsonError}
className={editMode === 'visual' ? 'shadow-sm' : ''}
>
{t('可视化')}
</Button>
<Button
size="small"
type={editMode === 'manual' ? 'primary' : 'tertiary'}
icon={<IconCode />}
onClick={toggleEditMode}
className={editMode === 'manual' ? 'shadow-sm' : ''}
>
{t('手动编辑')}
</Button>
</Space>
</div>
</div>
{/* JSON错误提示 */}
{hasJsonError && (
<Banner
type="danger"
description={`JSON 格式错误: ${jsonError}`}
className="!rounded-md text-sm"
/>
)}
{/* 编辑器内容 */}
{editMode === 'visual' ? (
<div>
<Card className="!p-3 !border-gray-200 !shadow-sm !rounded-md bg-white">
{renderVisualEditor()}
</Card>
{/* 可视化模式下的额外文本显示在下方 */}
{extraText && (
<div className="text-xs text-gray-600 mt-0.5">
{extraText}
</div>
)}
{/* 隐藏的Form字段用于验证和数据绑定 */}
<Form.Input
field={field}
value={value}
rules={rules}
style={{ display: 'none' }}
noLabel={true}
{...props}
/>
</div>
) : (
<Form.TextArea
field={field}
placeholder={placeholder}
value={value}
onChange={handleManualChange}
showClear={showClear}
rows={Math.max(8, value ? value.split('\n').length : 8)}
rules={rules}
noLabel={true}
{...props}
/>
)}
{/* 额外文本在手动编辑模式下显示 */}
{extraText && editMode === 'manual' && (
<div className="text-xs text-gray-600">
{extraText}
</div>
)}
</div>
);
};
export default JSONEditor;

View File

@@ -17,8 +17,6 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com For commercial licensing, please contact support@quantumnous.com
*/ */
import React, { useEffect, useState, useRef, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { import {
API, API,
showError, showError,
@@ -26,37 +24,42 @@ import {
showSuccess, showSuccess,
verifyJSON, verifyJSON,
} from '../../../../helpers'; } from '../../../../helpers';
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
import { CHANNEL_OPTIONS } from '../../../../constants';
import { import {
Avatar,
Banner,
Button,
Card,
Checkbox,
Col,
Form,
Highlight,
ImagePreview,
Input,
Modal,
Row,
SideSheet, SideSheet,
Space, Space,
Spin, Spin,
Button,
Typography,
Checkbox,
Banner,
Modal,
ImagePreview,
Card,
Tag, Tag,
Avatar, Typography,
Form,
Row,
Col,
Highlight,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers'; import { getChannelModels, copy, getChannelIcon, getModelCategories, selectFilter } from '../../../../helpers';
import ModelSelectModal from './ModelSelectModal'; import ModelSelectModal from './ModelSelectModal';
import JSONEditor from '../../../common/JSONEditor';
import { CHANNEL_OPTIONS, CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT } from '../../../../constants';
import { import {
IconSave, IconBolt,
IconClose, IconClose,
IconServer,
IconSetting,
IconCode, IconCode,
IconGlobe, IconGlobe,
IconBolt, IconSave,
IconServer,
IconSetting,
} from '@douyinfe/semi-icons'; } from '@douyinfe/semi-icons';
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { useIsMobile } from '../../../../hooks/common/useIsMobile.js';
import { useTranslation } from 'react-i18next';
const { Text, Title } = Typography; const { Text, Title } = Typography;
@@ -69,7 +72,9 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
}; };
const REGION_EXAMPLE = { const REGION_EXAMPLE = {
default: 'us-central1', "default": 'global',
"gemini-1.5-pro-002": "europe-west2",
"gemini-1.5-flash-002": "europe-west2",
'claude-3-5-sonnet-20240620': 'europe-west1', 'claude-3-5-sonnet-20240620': 'europe-west1',
}; };
@@ -90,6 +95,8 @@ function type2secretPrompt(type) {
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey'; return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API则直接输ApiKey';
case 51: case 51:
return '按照如下格式输入: Access Key ID|Secret Access Key'; return '按照如下格式输入: Access Key ID|Secret Access Key';
case 53:
return '按照如下格式输入AccessToken|RefreshToken';
default: default:
return '请输入渠道对应的鉴权密钥'; return '请输入渠道对应的鉴权密钥';
} }
@@ -142,6 +149,10 @@ const EditChannelModal = (props) => {
const [customModel, setCustomModel] = useState(''); const [customModel, setCustomModel] = useState('');
const [modalImageUrl, setModalImageUrl] = useState(''); const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [showOAuthModal, setShowOAuthModal] = useState(false);
const [authorizationCode, setAuthorizationCode] = useState('');
const [oauthParams, setOauthParams] = useState(null);
const [isExchangingCode, setIsExchangingCode] = useState(false);
const [modelModalVisible, setModelModalVisible] = useState(false); const [modelModalVisible, setModelModalVisible] = useState(false);
const [fetchedModels, setFetchedModels] = useState([]); const [fetchedModels, setFetchedModels] = useState([]);
const formApiRef = useRef(null); const formApiRef = useRef(null);
@@ -350,6 +361,24 @@ const EditChannelModal = (props) => {
data.system_prompt = ''; data.system_prompt = '';
} }
// 特殊处理Claude Code渠道的密钥拆分和系统提示词
if (data.type === 53) {
// 拆分密钥
if (data.key) {
const keyParts = data.key.split('|');
if (keyParts.length === 2) {
data.access_token = keyParts[0];
data.refresh_token = keyParts[1];
} else {
// 如果没有 | 分隔符表示只有access token
data.access_token = data.key;
data.refresh_token = '';
}
}
// 强制设置固定系统提示词
data.system_prompt = CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT;
}
setInputs(data); setInputs(data);
if (formApiRef.current) { if (formApiRef.current) {
formApiRef.current.setValues(data); formApiRef.current.setValues(data);
@@ -473,6 +502,72 @@ const EditChannelModal = (props) => {
} }
}; };
// 生成OAuth授权URL
const handleGenerateOAuth = async () => {
try {
setLoading(true);
const res = await API.get('/api/channel/claude/oauth/url');
if (res.data.success) {
setOauthParams(res.data.data);
setShowOAuthModal(true);
showSuccess(t('OAuth授权URL生成成功'));
} else {
showError(res.data.message || t('生成OAuth授权URL失败'));
}
} catch (error) {
showError(t('生成OAuth授权URL失败') + error.message);
} finally {
setLoading(false);
}
};
// 交换授权码
const handleExchangeCode = async () => {
if (!authorizationCode.trim()) {
showError(t('请输入授权码'));
return;
}
if (!oauthParams) {
showError(t('OAuth参数丢失请重新生成'));
return;
}
try {
setIsExchangingCode(true);
const res = await API.post('/api/channel/claude/oauth/exchange', {
authorization_code: authorizationCode,
code_verifier: oauthParams.code_verifier,
state: oauthParams.state,
});
if (res.data.success) {
const tokenData = res.data.data;
// 自动填充access token和refresh token
handleInputChange('access_token', tokenData.access_token);
handleInputChange('refresh_token', tokenData.refresh_token);
handleInputChange('key', `${tokenData.access_token}|${tokenData.refresh_token}`);
// 更新表单字段
if (formApiRef.current) {
formApiRef.current.setValue('access_token', tokenData.access_token);
formApiRef.current.setValue('refresh_token', tokenData.refresh_token);
}
setShowOAuthModal(false);
setAuthorizationCode('');
setOauthParams(null);
showSuccess(t('授权码交换成功已自动填充tokens'));
} else {
showError(res.data.message || t('授权码交换失败'));
}
} catch (error) {
showError(t('授权码交换失败:') + error.message);
} finally {
setIsExchangingCode(false);
}
};
useEffect(() => { useEffect(() => {
const modelMap = new Map(); const modelMap = new Map();
@@ -785,7 +880,7 @@ const EditChannelModal = (props) => {
const batchExtra = batchAllowed ? ( const batchExtra = batchAllowed ? (
<Space> <Space>
<Checkbox <Checkbox
disabled={isEdit} disabled={isEdit || inputs.type === 53}
checked={batch} checked={batch}
onChange={(e) => { onChange={(e) => {
const checked = e.target.checked; const checked = e.target.checked;
@@ -1121,6 +1216,49 @@ const EditChannelModal = (props) => {
/> />
)} )}
</> </>
) : inputs.type === 53 ? (
<>
<Form.Input
field='access_token'
label={isEdit ? t('Access Token编辑模式下保存的密钥不会显示') : t('Access Token')}
placeholder={t('sk-ant-xxx')}
rules={isEdit ? [] : [{ required: true, message: t('请输入Access Token') }]}
autoComplete='new-password'
onChange={(value) => {
handleInputChange('access_token', value);
// 同时更新key字段格式为access_token|refresh_token
const refreshToken = inputs.refresh_token || '';
handleInputChange('key', `${value}|${refreshToken}`);
}}
suffix={
<Button
size="small"
type="primary"
theme="light"
onClick={handleGenerateOAuth}
>
{t('生成OAuth授权码')}
</Button>
}
extraText={batchExtra}
showClear
/>
<Form.Input
field='refresh_token'
label={isEdit ? t('Refresh Token编辑模式下保存的密钥不会显示') : t('Refresh Token')}
placeholder={t('sk-ant-xxx可选')}
rules={[]}
autoComplete='new-password'
onChange={(value) => {
handleInputChange('refresh_token', value);
// 同时更新key字段格式为access_token|refresh_token
const accessToken = inputs.access_token || '';
handleInputChange('key', `${accessToken}|${value}`);
}}
extraText={batchExtra}
showClear
/>
</>
) : ( ) : (
<Form.Input <Form.Input
field='key' field='key'
@@ -1174,24 +1312,24 @@ const EditChannelModal = (props) => {
)} )}
{inputs.type === 41 && ( {inputs.type === 41 && (
<Form.TextArea <JSONEditor
field='other' field='other'
label={t('部署地区')} label={t('部署地区')}
placeholder={t( placeholder={t(
'请输入部署地区例如us-central1\n支持使用模型映射格式\n{\n "default": "us-central1",\n "claude-3-5-sonnet-20240620": "europe-west1"\n}' '请输入部署地区例如us-central1\n支持使用模型映射格式\n{\n "default": "us-central1",\n "claude-3-5-sonnet-20240620": "europe-west1"\n}'
)} )}
autosize value={inputs.other || ''}
onChange={(value) => handleInputChange('other', value)} onChange={(value) => handleInputChange('other', value)}
rules={[{ required: true, message: t('请填写部署地区') }]} rules={[{ required: true, message: t('请填写部署地区') }]}
template={REGION_EXAMPLE}
templateLabel={t('填入模板')}
editorType="region"
formApi={formApiRef.current}
extraText={ extraText={
<Text <Text type="tertiary" size="small">
className="!text-semi-color-primary cursor-pointer" {t('设置默认地区和特定模型的专用地区')}
onClick={() => handleInputChange('other', JSON.stringify(REGION_EXAMPLE, null, 2))}
>
{t('填入模板')}
</Text> </Text>
} }
showClear
/> />
)} )}
@@ -1447,24 +1585,24 @@ const EditChannelModal = (props) => {
showClear showClear
/> />
<Form.TextArea <JSONEditor
field='model_mapping' field='model_mapping'
label={t('模型重定向')} label={t('模型重定向')}
placeholder={ placeholder={
t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:') + t('此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:') +
`\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}` `\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`
} }
autosize value={inputs.model_mapping || ''}
onChange={(value) => handleInputChange('model_mapping', value)} onChange={(value) => handleInputChange('model_mapping', value)}
template={MODEL_MAPPING_EXAMPLE}
templateLabel={t('填入模板')}
editorType="keyValue"
formApi={formApiRef.current}
extraText={ extraText={
<Text <Text type="tertiary" size="small">
className="!text-semi-color-primary cursor-pointer" {t('键为请求中的模型名称,值为要替换的模型名称')}
onClick={() => handleInputChange('model_mapping', JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2))}
>
{t('填入模板')}
</Text> </Text>
} }
showClear
/> />
</Card> </Card>
@@ -1554,7 +1692,7 @@ const EditChannelModal = (props) => {
showClear showClear
/> />
<Form.TextArea <JSONEditor
field='status_code_mapping' field='status_code_mapping'
label={t('状态码复写')} label={t('状态码复写')}
placeholder={ placeholder={
@@ -1562,17 +1700,17 @@ const EditChannelModal = (props) => {
'\n' + '\n' +
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2) JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)
} }
autosize value={inputs.status_code_mapping || ''}
onChange={(value) => handleInputChange('status_code_mapping', value)} onChange={(value) => handleInputChange('status_code_mapping', value)}
template={STATUS_CODE_MAPPING_EXAMPLE}
templateLabel={t('填入模板')}
editorType="keyValue"
formApi={formApiRef.current}
extraText={ extraText={
<Text <Text type="tertiary" size="small">
className="!text-semi-color-primary cursor-pointer" {t('键为原状态码,值为要复写的状态码,仅影响本地判断')}
onClick={() => handleInputChange('status_code_mapping', JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2))}
>
{t('填入模板')}
</Text> </Text>
} }
showClear
/> />
</Card> </Card>
@@ -1585,14 +1723,6 @@ const EditChannelModal = (props) => {
</Avatar> </Avatar>
<div> <div>
<Text className="text-lg font-medium">{t('渠道额外设置')}</Text> <Text className="text-lg font-medium">{t('渠道额外设置')}</Text>
<div className="text-xs text-gray-600">
<Text
className="!text-semi-color-primary cursor-pointer"
onClick={() => window.open('https://github.com/QuantumNous/new-api/blob/main/docs/channel/other_setting.md')}
>
{t('设置说明')}
</Text>
</div>
</div> </div>
</div> </div>
@@ -1637,11 +1767,19 @@ const EditChannelModal = (props) => {
<Form.TextArea <Form.TextArea
field='system_prompt' field='system_prompt'
label={t('系统提示词')} label={t('系统提示词')}
placeholder={t('输入系统提示词,用户的系统提示词将优先于此设置')} placeholder={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : t('输入系统提示词,用户的系统提示词将优先于此设置')}
onChange={(value) => handleChannelSettingsChange('system_prompt', value)} onChange={(value) => {
if (inputs.type === 53) {
// Claude Code渠道系统提示词固定不允许修改
return;
}
handleChannelSettingsChange('system_prompt', value);
}}
disabled={inputs.type === 53}
value={inputs.type === 53 ? CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT : undefined}
autosize autosize
showClear showClear={inputs.type !== 53}
extraText={t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')} extraText={inputs.type === 53 ? t('Claude Code渠道系统提示词固定为官方CLI身份不可修改') : t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')}
/> />
</Card> </Card>
</div> </div>
@@ -1665,8 +1803,70 @@ const EditChannelModal = (props) => {
}} }}
onCancel={() => setModelModalVisible(false)} onCancel={() => setModelModalVisible(false)}
/> />
{/* OAuth Authorization Modal */}
<Modal
title={t('生成Claude Code OAuth授权码')}
visible={showOAuthModal}
onCancel={() => {
setShowOAuthModal(false);
setAuthorizationCode('');
setOauthParams(null);
}}
onOk={handleExchangeCode}
okText={isExchangingCode ? t('交换中...') : t('确认')}
cancelText={t('取消')}
confirmLoading={isExchangingCode}
width={600}
>
<div className="space-y-4">
<div>
<Text className="text-sm font-medium mb-2 block">{t('请访问以下授权地址:')}</Text>
<div className="p-3 bg-gray-50 rounded-lg border">
<Text
link
underline
className="text-sm font-mono break-all cursor-pointer text-blue-600 hover:text-blue-800"
onClick={() => {
if (oauthParams?.auth_url) {
window.open(oauthParams.auth_url, '_blank');
}
}}
>
{oauthParams?.auth_url || t('正在生成授权地址...')}
</Text>
<div className="mt-2">
<Text
copyable={{ content: oauthParams?.auth_url }}
type="tertiary"
size="small"
>
{t('复制链接')}
</Text>
</div>
</div>
</div>
<div>
<Text className="text-sm font-medium mb-2 block">{t('授权后,请将获得的授权码粘贴到下方:')}</Text>
<Input
value={authorizationCode}
onChange={setAuthorizationCode}
placeholder={t('请输入授权码')}
showClear
style={{ width: '100%' }}
/>
</div>
<Banner
type="info"
description={t('获得授权码后系统将自动换取access token和refresh token并填充到表单中。')}
className="!rounded-lg"
/>
</div>
</Modal>
</> </>
); );
}; };
export default EditChannelModal; export default EditChannelModal;

View File

@@ -159,6 +159,14 @@ export const CHANNEL_OPTIONS = [
color: 'purple', color: 'purple',
label: 'Vidu', label: 'Vidu',
}, },
{
value: 53,
color: 'indigo',
label: 'Claude Code',
},
]; ];
export const MODEL_TABLE_PAGE_SIZE = 10; export const MODEL_TABLE_PAGE_SIZE = 10;
// Claude Code 相关常量
export const CLAUDE_CODE_DEFAULT_SYSTEM_PROMPT = "You are Claude Code, Anthropic's official CLI for Claude.";

View File

@@ -367,6 +367,7 @@ export function getChannelIcon(channelType) {
return <Ollama size={iconSize} />; return <Ollama size={iconSize} />;
case 14: // Anthropic Claude case 14: // Anthropic Claude
case 33: // AWS Claude case 33: // AWS Claude
case 53: // Claude Code
return <Claude.Color size={iconSize} />; return <Claude.Color size={iconSize} />;
case 41: // Vertex AI case 41: // Vertex AI
return <Gemini.Color size={iconSize} />; return <Gemini.Color size={iconSize} />;