fix(gateway): 修复 Claude Code 客户端检测和请求信息记录

- 在 Messages 方法中调用 SetClaudeCodeClientContext 启用客户端检测
- 修复 RecordUsageInput 未传递 UserAgent 和 IPAddress 的问题
This commit is contained in:
shaw
2026-01-12 15:19:40 +08:00
parent 9c144587fe
commit cf313d5761
3 changed files with 35 additions and 8 deletions

View File

@@ -88,6 +88,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body) setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body) parsedReq, err := service.ParseGatewayRequest(body)
@@ -286,8 +289,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP()
// 异步记录使用量subscription已在函数开头获取 // 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -296,10 +303,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account) }(result, account, userAgent, clientIP)
return return
} }
} }
@@ -414,8 +423,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP()
// 异步记录使用量subscription已在函数开头获取 // 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -424,10 +437,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account) }(result, account, userAgent, clientIP)
return return
} }
} }

View File

@@ -314,8 +314,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return return
} }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP()
// 6) record usage async // 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) { go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -324,10 +328,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account) }(result, account, userAgent, clientIP)
return return
} }
} }

View File

@@ -263,8 +263,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context
userAgent := c.GetHeader("User-Agent")
clientIP := c.ClientIP()
// Async record usage // Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) { go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
@@ -273,10 +277,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
User: apiKey.User, User: apiKey.User,
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua,
IPAddress: ip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account) }(result, account, userAgent, clientIP)
return return
} }
} }