feat(gateway): 添加 /v1/messages/count_tokens 端点

实现 Claude API 的 token 计数功能,支持 OAuth、SetupToken 和 ApiKey 三种账号类型。

特点:
- 校验订阅/余额(不扣费)
- 不计算用户和账号并发
- 不记录使用量
- 支持模型映射(ApiKey 账号)
- 支持 OAuth 账号的指纹管理和 401 重试
This commit is contained in:
shaw
2025-12-19 11:12:41 +08:00
parent 2e59998c51
commit 52d2ae9708
3 changed files with 274 additions and 4 deletions

View File

@@ -443,3 +443,69 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
},
})
}
// CountTokens handles token counting endpoint
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 获取订阅信息可能为nil
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
}
}