Merge branch 'main' of https://github.com/mt21625457/aicodex2api
This commit is contained in:
@@ -20,7 +20,6 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -375,7 +374,8 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
@@ -389,6 +389,7 @@ type GatewayService struct {
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
@@ -410,6 +411,7 @@ func NewGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -429,6 +431,7 @@ func NewGatewayService(
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
@@ -624,35 +627,6 @@ func stripToolPrefix(value string) string {
|
||||
return toolPrefixRe.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
func toPascalCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
|
||||
tokens := make([]string, 0)
|
||||
for _, token := range strings.Fields(normalized) {
|
||||
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
|
||||
parts := strings.Fields(expanded)
|
||||
if len(parts) > 0 {
|
||||
tokens = append(tokens, parts...)
|
||||
}
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
return value
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, token := range tokens {
|
||||
lower := strings.ToLower(token)
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
runes := []rune(lower)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
_, _ = builder.WriteString(string(runes))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func toSnakeCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
@@ -668,16 +642,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
// 只对已知的工具名进行映射,未知工具名保持原样
|
||||
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728)
|
||||
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
|
||||
if !ok {
|
||||
mapped = toPascalCase(stripped)
|
||||
}
|
||||
if mapped != "" && cache != nil && mapped != stripped {
|
||||
cache[mapped] = stripped
|
||||
}
|
||||
if mapped == "" {
|
||||
return stripped
|
||||
}
|
||||
if cache != nil && mapped != stripped {
|
||||
cache[mapped] = stripped
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
@@ -686,15 +659,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
// 优先从请求时建立的映射中查找
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
// 已知工具名的硬编码映射
|
||||
if mapped, ok := openCodeToolOverrides[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
return toSnakeCase(stripped)
|
||||
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
|
||||
return stripped
|
||||
}
|
||||
|
||||
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
|
||||
@@ -3313,7 +3289,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -3343,10 +3319,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
@@ -3390,7 +3364,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -3787,6 +3761,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
return extractUpstreamErrorMessage(body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
@@ -3854,7 +3834,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||
@@ -4641,10 +4621,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -4827,10 +4814,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
|
||||
Reference in New Issue
Block a user