diff --git a/constant/context_key.go b/constant/context_key.go
index 4eaf3d00..32dd9617 100644
--- a/constant/context_key.go
+++ b/constant/context_key.go
@@ -11,7 +11,6 @@ const (
     ContextKeyTokenKey ContextKey = "token_key"
     ContextKeyTokenId ContextKey = "token_id"
     ContextKeyTokenGroup ContextKey = "token_group"
-    ContextKeyTokenAllowIps ContextKey = "allow_ips"
     ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
     ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
     ContextKeyTokenModelLimit ContextKey = "token_model_limit"
diff --git a/go.mod b/go.mod
index ae57a38d..91128b5b 100644
--- a/go.mod
+++ b/go.mod
@@ -7,9 +7,10 @@ require (
     github.com/Calcium-Ion/go-epay v0.0.4
     github.com/andybalholm/brotli v1.1.1
     github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
-    github.com/aws/aws-sdk-go-v2 v1.26.1
+    github.com/aws/aws-sdk-go-v2 v1.37.2
     github.com/aws/aws-sdk-go-v2/credentials v1.17.11
-    github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
+    github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
+    github.com/aws/smithy-go v1.22.5
     github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
     github.com/gin-contrib/cors v1.7.2
     github.com/gin-contrib/gzip v0.0.6
@@ -24,6 +25,7 @@ require (
     github.com/gorilla/websocket v1.5.0
     github.com/joho/godotenv v1.5.1
     github.com/pkg/errors v0.9.1
+    github.com/pquerna/otp v1.5.0
     github.com/samber/lo v1.39.0
     github.com/shirou/gopsutil v3.21.11+incompatible
     github.com/shopspring/decimal v1.4.0
@@ -42,10 +44,9 @@ require (
     filippo.io/edwards25519 v1.1.0 // indirect
     github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
-    github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
-    github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
-    github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
-    github.com/aws/smithy-go v1.20.2 // indirect
+    github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
+    github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
+    github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
     github.com/boombuler/barcode v1.1.0 // indirect
     github.com/bytedance/sonic v1.11.6 // indirect
     github.com/bytedance/sonic/loader v0.1.1 // indirect
@@ -81,7 +82,6 @@ require (
     github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
     github.com/modern-go/reflect2 v1.0.2 // indirect
     github.com/pelletier/go-toml/v2 v2.2.1 // indirect
-    github.com/pquerna/otp v1.5.0 // indirect
     github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
     github.com/tklauser/go-sysconf v0.3.12 // indirect
     github.com/tklauser/numcpus v0.6.1 // indirect
diff --git a/go.sum b/go.sum
index 6dffcff5..eed0ae46 100644
--- a/go.sum
+++ b/go.sum
@@ -8,21 +8,20 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
 github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
 github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
-github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
-github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
+github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
+github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
 github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
 github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
-github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
-github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
-github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
+github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
+github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
 github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
 github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
 github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
diff --git a/middleware/auth.go b/middleware/auth.go
index 72900f83..5f6e5d43 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -4,7 +4,10 @@ import (
     "fmt"
     "net/http"
     "one-api/common"
+    "one-api/constant"
     "one-api/model"
+    "one-api/setting"
+    "one-api/setting/ratio_setting"
     "strconv"
     "strings"
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
             abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
             return
         }
+
+        allowIpsMap := token.GetIpLimitsMap()
+        if len(allowIpsMap) != 0 {
+            clientIp := c.ClientIP()
+            if _, ok := allowIpsMap[clientIp]; !ok {
+                abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+                return
+            }
+        }
+
         userCache, err := model.GetUserCache(token.UserId)
         if err != nil {
             abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
         userCache.WriteContext(c)
 
+        userGroup := userCache.Group
+        tokenGroup := token.Group
+        if tokenGroup != "" {
+            // check common.UserUsableGroups[userGroup]
+            if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
+                abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
+                return
+            }
+            // check group in common.GroupRatio
+            if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+                if tokenGroup != "auto" {
+                    abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+                    return
+                }
+            }
+            userGroup = tokenGroup
+        }
+        common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
+
         err = SetupContextForToken(c, token, parts...)
         if err != nil {
             return
         }
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
     } else {
         c.Set("token_model_limit_enabled", false)
     }
-    c.Set("allow_ips", token.GetIpLimitsMap())
     c.Set("token_group", token.Group)
     if len(parts) > 1 {
         if model.IsAdmin(token.UserId) {
diff --git a/middleware/distributor.go b/middleware/distributor.go
index c7a55f4c..5fae6322 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -10,7 +10,6 @@ import (
     "one-api/model"
     relayconstant "one-api/relay/constant"
     "one-api/service"
-    "one-api/setting"
     "one-api/setting/ratio_setting"
     "one-api/types"
     "strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
     return func(c *gin.Context) {
-        allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
-        if len(allowIpsMap) != 0 {
-            clientIp := c.ClientIP()
-            if _, ok := allowIpsMap[clientIp]; !ok {
-                abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
-                return
-            }
-        }
         var channel *model.Channel
         channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
         modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
             abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
             return
         }
-        userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
-        tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
-        if tokenGroup != "" {
-            // check common.UserUsableGroups[userGroup]
-            if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
-                abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
-                return
-            }
-            // check group in common.GroupRatio
-            if !ratio_setting.ContainsGroupRatio(tokenGroup) {
-                if tokenGroup != "auto" {
-                    abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
-                    return
-                }
-            }
-            userGroup = tokenGroup
-        }
-        common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
         if ok {
             id, err := strconv.Atoi(channelId.(string))
             if err != nil {
@@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) {
         modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
         if modelLimitEnable {
             s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
-            var tokenModelLimit map[string]bool
-            if ok {
-                tokenModelLimit = s.(map[string]bool)
-            } else {
-                tokenModelLimit = map[string]bool{}
-            }
-            if tokenModelLimit != nil {
-                if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
-                    abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
-                    return
-                }
-            } else {
+            if !ok {
                 // token model limit is empty, all models are not allowed
                 abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
                 return
+            }
+            var tokenModelLimit map[string]bool
+            tokenModelLimit, ok = s.(map[string]bool)
+            if !ok {
+                tokenModelLimit = map[string]bool{}
+            }
+            matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
+            if _, ok := tokenModelLimit[matchName]; !ok {
+                abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+                return
+            }
         }
 
         if shouldSelectChannel {
@@ -105,6 +77,7 @@
                 return
             }
             var selectGroup string
+            userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
             channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
             if err != nil {
                 showGroup := userGroup
diff --git a/model/channel_cache.go b/model/channel_cache.go
index 6ca23cf9..90bd2ad1 100644
--- a/model/channel_cache.go
+++ b/model/channel_cache.go
@@ -7,6 +7,7 @@ import (
     "one-api/common"
     "one-api/constant"
     "one-api/setting"
+    "one-api/setting/ratio_setting"
     "sort"
     "strings"
     "sync"
@@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
 }
 
 func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
-    if strings.HasPrefix(model, "gpt-4-gizmo") {
-        model = "gpt-4-gizmo-*"
-    }
-    if strings.HasPrefix(model, "gpt-4o-gizmo") {
-        model = "gpt-4o-gizmo-*"
-    }
+    model = ratio_setting.FormatMatchingModelName(model)
 
     // if memory cache is disabled, get channel directly from database
     if !common.MemoryCacheEnabled {
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
index 64c7b747..3f8800b1 100644
--- a/relay/channel/aws/constants.go
+++ b/relay/channel/aws/constants.go
@@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{
     "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
     "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
     "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
+    "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
 }
 
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
     "anthropic.claude-opus-4-20250514-v1:0": {
         "us": true,
     },
+    "anthropic.claude-opus-4-1-20250805-v1:0": {
+        "us": true,
+    },
 }
 
 var awsRegionCrossModelPrefixMap = map[string]string{
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
index 0df19e07..04427ab8 100644
--- a/relay/channel/aws/relay-aws.go
+++ b/relay/channel/aws/relay-aws.go
@@ -19,20 +19,31 @@ import (
     "github.com/aws/aws-sdk-go-v2/credentials"
     "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
     bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+    "github.com/aws/smithy-go/auth/bearer"
 )
 
 func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
     awsSecret := strings.Split(info.ApiKey, "|")
-    if len(awsSecret) != 3 {
+    var client *bedrockruntime.Client
+    switch len(awsSecret) {
+    case 2:
+        apiKey := awsSecret[0]
+        region := awsSecret[1]
+        client = bedrockruntime.New(bedrockruntime.Options{
+            Region: region,
+            BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
+        })
+    case 3:
+        ak := awsSecret[0]
+        sk := awsSecret[1]
+        region := awsSecret[2]
+        client = bedrockruntime.New(bedrockruntime.Options{
+            Region: region,
+            Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
+        })
+    default:
         return nil, errors.New("invalid aws secret key")
     }
-    ak := awsSecret[0]
-    sk := awsSecret[1]
-    region := awsSecret[2]
-    client := bedrockruntime.New(bedrockruntime.Options{
-        Region: region,
-        Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
-    })
     return client, nil
 }
diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go
index e0e3c421..a23543d2 100644
--- a/relay/channel/claude/constants.go
+++ b/relay/channel/claude/constants.go
@@ -17,6 +17,8 @@ var ModelList = []string{
     "claude-sonnet-4-20250514-thinking",
     "claude-opus-4-20250514",
     "claude-opus-4-20250514-thinking",
+    "claude-opus-4-1-20250805",
+    "claude-opus-4-1-20250805-thinking",
 }
 
 var ChannelName = "claude"
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index adc771e2..18524afb 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -49,12 +49,20 @@ const (
     flash25LiteMaxBudget = 24576
 )
 
-// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
-func clampThinkingBudget(modelName string, budget int) int {
-    isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+func isNew25ProModel(modelName string) bool {
+    return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
         !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
         !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
-    is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+}
+
+func is25FlashLiteModel(modelName string) bool {
+    return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+}
+
+// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
+func clampThinkingBudget(modelName string, budget int) int {
+    isNew25Pro := isNew25ProModel(modelName)
+    is25FlashLite := is25FlashLiteModel(modelName)
 
     if is25FlashLite {
         if budget < flash25LiteMinBudget {
@@ -81,7 +89,34 @@ func clampThinkingBudget(modelName string, budget int) int {
     return budget
 }
 
-func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
+// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
+// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
+// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
+func clampThinkingBudgetByEffort(modelName string, effort string) int {
+    isNew25Pro := isNew25ProModel(modelName)
+    is25FlashLite := is25FlashLiteModel(modelName)
+
+    maxBudget := 0
+    if is25FlashLite {
+        maxBudget = flash25LiteMaxBudget
+    }
+    if isNew25Pro {
+        maxBudget = pro25MaxBudget
+    } else {
+        maxBudget = flash25MaxBudget
+    }
+    switch effort {
+    case "high":
+        maxBudget = maxBudget * 80 / 100
+    case "medium":
+        maxBudget = maxBudget * 50 / 100
+    case "low":
+        maxBudget = maxBudget * 20 / 100
+    }
+    return clampThinkingBudget(modelName, maxBudget)
+}
+
+func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
     if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
         modelName := info.UpstreamModelName
         isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
             !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
             !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
@@ -124,6 +159,11 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
                 budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
                 clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
                 geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
+            } else {
+                if len(oaiRequest) > 0 {
+                    // 如果有reasoningEffort参数,则根据其值设置思考预算
+                    geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
+                }
             }
         }
     } else if strings.HasSuffix(modelName, "-nothinking") {
@@ -156,7 +196,37 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
         }
     }
 
-    ThinkingAdaptor(&geminiRequest, info)
+    adaptorWithExtraBody := false
+
+    if len(textRequest.ExtraBody) > 0 {
+        if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
+            var extraBody map[string]interface{}
+            if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
+                return nil, fmt.Errorf("invalid extra body: %w", err)
+            }
+            // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
+            if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
+                adaptorWithExtraBody = true
+                if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
+                    if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
+                        budgetInt := int(budget)
+                        geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+                            ThinkingBudget: common.GetPointer(budgetInt),
+                            IncludeThoughts: true,
+                        }
+                    } else {
+                        geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+                            IncludeThoughts: true,
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if !adaptorWithExtraBody {
+        ThinkingAdaptor(&geminiRequest, info, textRequest)
+    }
 
     safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
     for _, category := range SafetySettingList {
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index f46af710..115b7b76 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -9,6 +9,7 @@ import (
     "mime/multipart"
     "net/http"
     "net/textproto"
+    "one-api/common"
     "one-api/constant"
     "one-api/dto"
     "one-api/relay/channel"
@@ -172,6 +173,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
         if len(request.Usage) == 0 {
             request.Usage = json.RawMessage(`{"include":true}`)
         }
+        if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
+            info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
+            request.Model = info.UpstreamModelName
+            if len(request.Reasoning) == 0 {
+                reasoning := map[string]any{
+                    "enabled": true,
+                }
+                if request.ReasoningEffort != "" {
+                    reasoning["effort"] = request.ReasoningEffort
+                }
+                marshal, err := common.Marshal(reasoning)
+                if err != nil {
+                    return nil, fmt.Errorf("error marshalling reasoning: %w", err)
+                }
+                request.Reasoning = marshal
+            }
+        }
     }
     if strings.HasPrefix(request.Model, "o") {
         if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index ef063e7c..bae6fcb6 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -37,9 +37,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
     // compute usage
     usage := dto.Usage{}
-    usage.PromptTokens = responsesResponse.Usage.InputTokens
-    usage.CompletionTokens = responsesResponse.Usage.OutputTokens
-    usage.TotalTokens = responsesResponse.Usage.TotalTokens
+    if responsesResponse.Usage != nil {
+        usage.PromptTokens = responsesResponse.Usage.InputTokens
+        usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+        usage.TotalTokens = responsesResponse.Usage.TotalTokens
+        if responsesResponse.Usage.InputTokensDetails != nil {
+            usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
+        }
+    }
     // 解析 Tools 用量
     for _, tool := range responsesResponse.Tools {
         info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
@@ -64,9 +69,14 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
             sendResponsesStreamData(c, streamResponse, data)
             switch streamResponse.Type {
             case "response.completed":
-                usage.PromptTokens = streamResponse.Response.Usage.InputTokens
-                usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
-                usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+                if streamResponse.Response.Usage != nil {
+                    usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+                    usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+                    usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+                    if streamResponse.Response.Usage.InputTokensDetails != nil {
+                        usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+                    }
+                }
             case "response.output_text.delta":
                 // 处理输出文本
                 responseTextBuilder.WriteString(streamResponse.Delta)
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 9b62cffc..4648a384 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -35,6 +35,7 @@ var claudeModelMap = map[string]string{
     "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
     "claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
     "claude-opus-4-20250514": "claude-opus-4@20250514",
+    "claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
 }
 
 const anthropicVersion = "vertex-2023-10-16"
diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go
index 51d473a8..8b87cb86 100644
--- a/setting/ratio_setting/cache_ratio.go
+++ b/setting/ratio_setting/cache_ratio.go
@@ -40,6 +40,8 @@ var defaultCacheRatio = map[string]float64{
     "claude-sonnet-4-20250514-thinking": 0.1,
    "claude-opus-4-20250514": 0.1,
     "claude-opus-4-20250514-thinking": 0.1,
+    "claude-opus-4-1-20250805": 0.1,
+    "claude-opus-4-1-20250805-thinking": 0.1,
 }
 
 var defaultCreateCacheRatio = map[string]float64{
@@ -55,6 +57,8 @@ var defaultCreateCacheRatio = map[string]float64{
     "claude-sonnet-4-20250514-thinking": 1.25,
     "claude-opus-4-20250514": 1.25,
     "claude-opus-4-20250514-thinking": 1.25,
+    "claude-opus-4-1-20250805": 1.25,
+    "claude-opus-4-1-20250805-thinking": 1.25,
 }
 
 //var defaultCreateCacheRatio = map[string]float64{}
diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go
index 8a1d6aae..647cc1f4 100644
--- a/setting/ratio_setting/model_ratio.go
+++ b/setting/ratio_setting/model_ratio.go
@@ -118,6 +118,7 @@ var defaultModelRatio = map[string]float64{
     "claude-sonnet-4-20250514": 1.5,
     "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
     "claude-opus-4-20250514": 7.5,
+    "claude-opus-4-1-20250805": 7.5,
     "ERNIE-4.0-8K": 0.120 * RMB,
     "ERNIE-3.5-8K": 0.012 * RMB,
     "ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -334,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
     modelPriceMapMutex.RLock()
     defer modelPriceMapMutex.RUnlock()
 
-    if strings.HasPrefix(name, "gpt-4-gizmo") {
-        name = "gpt-4-gizmo-*"
-    }
-    if strings.HasPrefix(name, "gpt-4o-gizmo") {
= "gpt-4o-gizmo-*" - } + name = FormatMatchingModelName(name) + price, ok := modelPriceMap[name] if !ok { if printErr { @@ -373,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") - name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } + name = FormatMatchingModelName(name) + ratio, ok := modelRatioMap[name] if !ok { return 37.5, operation_setting.SelfUseModeEnabled, name @@ -428,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { func GetCompletionRatio(name string) float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } + + name = FormatMatchingModelName(name) + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -663,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 { } return copyMap } + +// 转换模型名,减少渠道必须配置各种带参数模型 +func FormatMatchingModelName(name string) string { + name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") + name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") + if strings.HasPrefix(name, "gpt-4-gizmo") { + name = "gpt-4-gizmo-*" + } + if strings.HasPrefix(name, "gpt-4o-gizmo") { + name = "gpt-4o-gizmo-*" + } + return name +} diff --git a/types/error.go b/types/error.go index e7265e21..d3dd29e1 100644 --- a/types/error.go +++ b/types/error.go @@ -189,9 +189,13 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI } func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { + if errorCode == ErrorCodeDoRequestFailed { + err = errors.New("upstream error: do request failed") + } openaiError := OpenAIError{ Message: err.Error(), Type: string(errorCode), + Code: errorCode, } return WithOpenAIError(openaiError, statusCode, ops...) } @@ -199,6 +203,7 @@ func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAP func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { openaiError := OpenAIError{ Type: string(errorCode), + Code: errorCode, } return WithOpenAIError(openaiError, statusCode, ops...) 
 }
@@ -224,7 +229,11 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops
 func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
     code, ok := openAIError.Code.(string)
     if !ok {
-        code = fmt.Sprintf("%v", openAIError.Code)
+        if openAIError.Code == nil {
+            code = fmt.Sprintf("%v", openAIError.Code)
+        } else {
+            code = "unknown_error"
+        }
     }
     if openAIError.Type == "" {
         openAIError.Type = "upstream_error"
diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js
index ef946890..1186c19f 100644
--- a/web/src/helpers/render.js
+++ b/web/src/helpers/render.js
@@ -1182,6 +1182,7 @@ export function renderLogContent(
   modelPrice = -1,
   groupRatio,
   user_group_ratio,
+  cacheRatio = 1.0,
   image = false,
   imageRatio = 1.0,
   webSearch = false,
@@ -1200,9 +1201,10 @@
   } else {
     if (image) {
       return i18next.t(
-        '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}',
+        '模型倍率 {{modelRatio}},缓存倍率 {{cacheRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}',
         {
           modelRatio: modelRatio,
+          cacheRatio: cacheRatio,
           completionRatio: completionRatio,
           imageRatio: imageRatio,
           ratioType: ratioLabel,
@@ -1211,9 +1213,10 @@
       );
     } else if (webSearch) {
       return i18next.t(
-        '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}},Web 搜索调用 {{webSearchCallCount}} 次',
+        '模型倍率 {{modelRatio}},缓存倍率 {{cacheRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}},Web 搜索调用 {{webSearchCallCount}} 次',
         {
           modelRatio: modelRatio,
+          cacheRatio: cacheRatio,
           completionRatio: completionRatio,
           ratioType: ratioLabel,
           ratio,
@@ -1222,9 +1225,10 @@
       );
     } else {
      return i18next.t(
-        '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}',
+        '模型倍率 {{modelRatio}},缓存倍率 {{cacheRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}',
         {
           modelRatio: modelRatio,
+          cacheRatio: cacheRatio,
           completionRatio: completionRatio,
           ratioType: ratioLabel,
           ratio,
diff --git a/web/src/hooks/usage-logs/useUsageLogsData.js b/web/src/hooks/usage-logs/useUsageLogsData.js
index 03e09eb8..0c6c4452 100644
--- a/web/src/hooks/usage-logs/useUsageLogsData.js
+++ b/web/src/hooks/usage-logs/useUsageLogsData.js
@@ -366,6 +366,7 @@
             other.model_price,
             other.group_ratio,
             other?.user_group_ratio,
+            other.cache_ratio || 1.0,
             false,
             1.0,
             other.web_search || false,
diff --git a/web/src/pages/Setting/Ratio/ModelSettingsVisualEditor.js b/web/src/pages/Setting/Ratio/ModelSettingsVisualEditor.js
index 2aa45ace..1205f6d8 100644
--- a/web/src/pages/Setting/Ratio/ModelSettingsVisualEditor.js
+++ b/web/src/pages/Setting/Ratio/ModelSettingsVisualEditor.js
@@ -44,6 +44,7 @@ export default function ModelSettingsVisualEditor(props) {
   const { t } = useTranslation();
   const [models, setModels] = useState([]);
   const [visible, setVisible] = useState(false);
+  const [isEditMode, setIsEditMode] = useState(false);
   const [currentModel, setCurrentModel] = useState(null);
   const [searchText, setSearchText] = useState('');
   const [currentPage, setCurrentPage] = useState(1);
@@ -386,9 +387,11 @@
     setCurrentModel(null);
     setPricingMode('per-token');
     setPricingSubMode('ratio');
+    setIsEditMode(false);
   };
 
   const editModel = (record) => {
+    setIsEditMode(true);
    // Determine which pricing mode to use based on the model's current configuration
     let initialPricingMode = 'per-token';
     let initialPricingSubMode = 'ratio';
@@ -500,13 +503,7 @@
-          title={
-            currentModel &&
-            currentModel.name &&
-            models.some((model) => model.name === currentModel.name)
-              ? t('编辑模型')
-              : t('添加模型')
-          }
+          title={isEditMode ? t('编辑模型') : t('添加模型')}
           visible={visible}
           onCancel={() => {
             resetModalState();
@@ -562,11 +559,7 @@
             label={t('模型名称')}
             placeholder='strawberry'
             required
-            disabled={
-              currentModel &&
-              currentModel.name &&
-              models.some((model) => model.name === currentModel.name)
-            }
+            disabled={isEditMode}
             onChange={(value) =>
               setCurrentModel((prev) => ({ ...prev, name: value }))
             }