Merge branch 'main' into test

This commit is contained in:
yangjianbo
2026-02-04 20:35:09 +08:00
90 changed files with 5735 additions and 749 deletions

View File

@@ -257,6 +257,9 @@ var (
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
@@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string {
}
// replaceModelInBody 替换请求体中的model字段
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]any
var req map[string]json.RawMessage
if err := json.Unmarshal(body, &req); err != nil {
return body
}
req["model"] = newModel
// 只序列化 model 字段
modelBytes, err := json.Marshal(newModel)
if err != nil {
return body
}
req["model"] = modelBytes
newBody, err := json.Marshal(req)
if err != nil {
return body
@@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if len(body) == 0 {
return body, modelID, nil
}
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
var reqRaw map[string]json.RawMessage
if err := json.Unmarshal(body, &reqRaw); err != nil {
return body, modelID, nil
}
// 同时解析为 map[string]any 用于修改非 messages 字段
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body, modelID, nil
}
toolNameMap := make(map[string]string)
modified := false
if system, ok := req["system"]; ok {
switch v := system.(type) {
@@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized := sanitizeSystemText(v)
if sanitized != v {
req["system"] = sanitized
modified = true
}
case []any:
for _, item := range v {
@@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
sanitized := sanitizeSystemText(text)
if sanitized != text {
block["text"] = sanitized
modified = true
}
}
}
@@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if normalized != rawModel {
req["model"] = normalized
modelID = normalized
modified = true
}
}
@@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
toolMap["name"] = normalized
modified = true
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
modified = true
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
modified = true
}
tools[idx] = toolMap
}
@@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalizedTools[normalized] = value
}
req["tools"] = normalizedTools
modified = true
}
} else {
req["tools"] = []any{}
modified = true
}
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
messagesModified := false
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
@@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if !ok {
continue
}
// 检查此消息是否包含 thinking 块
hasThinking := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
hasThinking = true
break
}
}
// 如果包含 thinking 块,跳过此消息的修改
if hasThinking {
continue
}
// 只修改不包含 thinking 块的消息中的 tool_use
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
@@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized := normalizeToolNameForClaude(name, toolNameMap)
if normalized != "" && normalized != name {
blockMap["name"] = normalized
messagesModified = true
}
}
}
@@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if opts.stripSystemCacheControl {
if system, ok := req["system"]; ok {
_ = stripCacheControlFromSystemBlocks(system)
modified = true
}
}
@@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
metadata["user_id"] = opts.metadataUserID
modified = true
}
}
delete(req, "temperature")
delete(req, "tool_choice")
if _, hasTemp := req["temperature"]; hasTemp {
delete(req, "temperature")
modified = true
}
if _, hasChoice := req["tool_choice"]; hasChoice {
delete(req, "tool_choice")
modified = true
}
if !modified && !messagesModified {
return body, modelID, toolNameMap
}
// 如果 messages 没有被修改,保留原始 messages 字节
if !messagesModified {
// 序列化非 messages 字段
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
}
// 替换回原始的 messages
var newReq map[string]json.RawMessage
if err := json.Unmarshal(newBody, &newReq); err != nil {
return newBody, modelID, toolNameMap
}
if origMessages, ok := reqRaw["messages"]; ok {
newReq["messages"] = origMessages
}
finalBody, err := json.Marshal(newReq)
if err != nil {
return newBody, modelID, toolNameMap
}
return finalBody, modelID, toolNameMap
}
// messages 被修改了,需要完整序列化
newBody, err := json.Marshal(req)
if err != nil {
return body, modelID, toolNameMap
@@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
@@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return group, nil
}
// ResolveGroupByID is the exported counterpart of resolveGroupByID. It forwards
// ctx and groupID to that unexported method and returns its group and error
// results unchanged.
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
return s.resolveGroupByID(ctx, groupID)
}
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
return nil
@@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
return nil, groupID, nil
}
@@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
@@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true
@@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射仅对apikey类型账号
// 应用模型映射
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射短ID → 长ID
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
@@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
return true
}
// 检测 thinking block 被修改的错误
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block modification error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
@@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater // 可选用于更新API Key配额
}
// APIKeyQuotaUpdater defines the interface for updating API Key quota.
//
// RecordUsageInput.APIKeyService holds a value of this type. RecordUsage calls
// UpdateQuotaUsed on it with the API key's ID and the request's actual cost,
// and only logs a returned error rather than propagating it.
type APIKeyQuotaUpdater interface {
// UpdateQuotaUsed reports cost against the API key identified by apiKeyID.
// NOTE(review): presumably this adds cost to the key's used quota; the
// implementation is not visible here, so confirm against the implementer.
UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
}
// 更新 API Key 配额(如果设置了配额限制)
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Update API key quota failed: %v", err)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
@@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct {
IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000)
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
APIKeyService *APIKeyService // API Key 配额服务(可选)
}
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
@@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
// API Key 独立配额扣费
if input.APIKeyService != nil && apiKey.Quota > 0 {
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
log.Printf("Add API key quota used failed: %v", err)
}
}
}
}
@@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
// 应用模型映射
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射短ID → 长ID
if reqModel != "" {
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
}
}
// 获取凭证
@@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
return normalized, nil
}
// checkAntigravityModelScope verifies that the model scope (model family) of
// requestedModel is within the scopes supported by the group groupID.
//
// It returns ErrModelScopeNotSupported when the group is found and
// IsScopeSupported rejects the resolved scope for group.SupportedModelScopes.
// The check is deliberately fail-open and returns nil when:
//   - ResolveAntigravityQuotaScope cannot resolve a scope for the model,
//   - the group lookup returns an error (the error is logged), or
//   - the group lookup returns a nil group.
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
	scope, ok := ResolveAntigravityQuotaScope(requestedModel)
	if !ok {
		// The scope cannot be resolved, so there is nothing to enforce.
		return nil
	}

	group, err := s.resolveGroupByID(ctx, groupID)
	if err != nil {
		// Let the request through on lookup failure, but leave a trace:
		// otherwise a failing lookup silently disables the restriction.
		log.Printf("[ModelScope] group lookup failed, skipping scope check: group_id=%d model=%s err=%v", groupID, requestedModel, err)
		return nil
	}
	if group == nil {
		// A missing group is also let through.
		return nil
	}

	if !IsScopeSupported(group.SupportedModelScopes, scope) {
		return ErrModelScopeNotSupported
	}
	return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {