diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index e620735d..842fe06e 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -443,3 +443,77 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
 		},
 	})
 }
+
+// CountTokens handles the token counting endpoint.
+// POST /v1/messages/count_tokens
+//
+// It checks subscription/balance eligibility, but takes no concurrency slot
+// and records no usage.
+func (h *GatewayHandler) CountTokens(c *gin.Context) {
+	// The API key and user are read from the context (set by the ApiKeyAuth middleware).
+	apiKey, ok := middleware.GetApiKeyFromContext(c)
+	if !ok {
+		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+		return
+	}
+
+	user, ok := middleware.GetUserFromContext(c)
+	if !ok {
+		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+		return
+	}
+
+	// Read the request body.
+	body, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
+		return
+	}
+
+	if len(body) == 0 {
+		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
+		return
+	}
+
+	// Parse the request to get the model name.
+	var req struct {
+		Model string `json:"model"`
+	}
+	if err := json.Unmarshal(body, &req); err != nil {
+		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+		return
+	}
+	if req.Model == "" {
+		// The model drives account selection below; reject a missing one
+		// here instead of spending an upstream call on an invalid request.
+		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
+		return
+	}
+
+	// Fetch the subscription (may be nil).
+	subscription, _ := middleware.GetSubscriptionFromContext(c)
+
+	// Check billing eligibility (subscription/balance).
+	// Note: concurrency is not counted, but eligibility must still be verified.
+	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
+		h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
+		return
+	}
+
+	// Compute the sticky-session hash.
+	sessionHash := h.gatewayService.GenerateSessionHash(body)
+
+	// Select an account that supports the model.
+	account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
+	if err != nil {
+		h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
+		return
+	}
+
+	// Forward the request (no usage is recorded).
+	if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
+		log.Printf("Forward count_tokens request failed: %v", err)
+		// The error response has already been written by ForwardCountTokens.
+		return
+	}
+}
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index caa122eb..04f5bfca 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -281,6 +281,7 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
 	gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription))
 	{
 		gateway.POST("/messages", h.Gateway.Messages)
+		gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
 		gateway.GET("/models", h.Gateway.Models)
 		gateway.GET("/usage", h.Gateway.Usage)
 	}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 24f4a9ed..55dfc784 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -27,10 +27,11 @@ import (
 )
 
 const (
-	claudeAPIURL        = "https://api.anthropic.com/v1/messages?beta=true"
-	stickySessionPrefix = "sticky_session:"
-	stickySessionTTL    = time.Hour // 粘性会话TTL
-	tokenRefreshBuffer  = 5 * 60    // 提前5分钟刷新token
+	claudeAPIURL            = "https://api.anthropic.com/v1/messages?beta=true"
+	claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
+	stickySessionPrefix     = "sticky_session:"
+	stickySessionTTL        = time.Hour // sticky session TTL
+	tokenRefreshBuffer      = 5 * 60    // refresh the token 5 minutes early
 )
 
 // allowedHeaders 白名单headers(参考CRS项目)
@@ -1044,3 +1045,223 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
 
 	return nil
 }
+
+// ForwardCountTokens forwards a count_tokens request to the upstream API and
+// writes the upstream response, or an error payload, to c.
+//
+// No usage is recorded and only non-streaming responses are supported. Whenever
+// a non-nil error is returned, an error response has already been written to c,
+// so the caller must not write another one.
+func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
+	// Apply model mapping (API-key accounts only).
+	if account.Type == model.AccountTypeApiKey {
+		var req struct {
+			Model string `json:"model"`
+		}
+		if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
+			mappedModel := account.GetMappedModel(req.Model)
+			if mappedModel != req.Model {
+				body = s.replaceModelInBody(body, mappedModel)
+				log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
+			}
+		}
+	}
+
+	// Obtain the credential.
+	token, tokenType, err := s.GetAccessToken(ctx, account)
+	if err != nil {
+		s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
+		return fmt.Errorf("get access token: %w", err)
+	}
+
+	// Build the upstream request.
+	upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
+	if err != nil {
+		s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
+		return fmt.Errorf("build upstream request: %w", err)
+	}
+
+	// Pick the HTTP client: the account's proxied client when one was built.
+	httpClient := s.httpClient
+	if upstreamResult.Client != nil {
+		httpClient = upstreamResult.Client
+	}
+
+	// Send the request.
+	resp, err := httpClient.Do(upstreamResult.Request)
+	if err != nil {
+		s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
+		return fmt.Errorf("upstream request failed: %w", err)
+	}
+
+	// On a 401 with an OAuth token, force a token refresh and retry once.
+	if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
+		// Close the rejected response here; the deferred Close below covers only
+		// the response that is actually read.
+		resp.Body.Close()
+		token, tokenType, err = s.forceRefreshToken(ctx, account)
+		if err != nil {
+			s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
+			return fmt.Errorf("token refresh failed: %w", err)
+		}
+		upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
+		if err != nil {
+			// Write an error payload here too, so every error return has responded.
+			s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
+			return fmt.Errorf("build retry request: %w", err)
+		}
+		httpClient = s.httpClient
+		if upstreamResult.Client != nil {
+			httpClient = upstreamResult.Client
+		}
+		resp, err = httpClient.Do(upstreamResult.Request)
+		if err != nil {
+			s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
+			return fmt.Errorf("retry request failed: %w", err)
+		}
+	}
+	defer resp.Body.Close()
+
+	// Read the response body.
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
+		return fmt.Errorf("read upstream response: %w", err)
+	}
+
+	// Handle upstream error responses.
+	if resp.StatusCode >= 400 {
+		// Report the failure so the account state can be updated (429/529, etc.).
+		s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
+
+		// Reply with a simplified error instead of the upstream body.
+		errMsg := "Upstream request failed"
+		switch resp.StatusCode {
+		case 429:
+			errMsg = "Rate limit exceeded"
+		case 529:
+			errMsg = "Service overloaded"
+		}
+		s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
+		return fmt.Errorf("upstream error: %d", resp.StatusCode)
+	}
+
+	// Pass the successful response through.
+	c.Data(resp.StatusCode, "application/json", respBody)
+	return nil
+}
+
+// buildCountTokensRequest builds the upstream count_tokens request.
+//
+// The result carries the request and, when the account has a proxy with a
+// parseable URL, a dedicated client; otherwise Client is nil.
+func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
+	// Determine the target URL.
+	targetURL := claudeAPICountTokensURL
+	if account.Type == model.AccountTypeApiKey {
+		// Trim trailing slashes so a base URL ending in "/" does not yield "//v1/...".
+		targetURL = strings.TrimRight(account.GetBaseURL(), "/") + "/v1/messages/count_tokens"
+	}
+
+	// OAuth accounts: look the unified fingerprint up once. It is used to rewrite
+	// the user ID in the body now and is applied to the headers further down.
+	applyFingerprint := func(*http.Request) {}
+	if account.IsOAuth() && s.identityService != nil {
+		fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
+		switch {
+		case err != nil:
+			// Best effort: continue without a fingerprint.
+			log.Printf("CountTokens fingerprint lookup failed (account: %s): %v", account.Name, err)
+		case fp != nil:
+			accountUUID := account.GetExtraString("account_uuid")
+			if accountUUID != "" && fp.ClientID != "" {
+				if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
+					body = newBody
+				}
+			}
+			applyFingerprint = func(req *http.Request) {
+				s.identityService.ApplyFingerprint(req, fp)
+			}
+		}
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	// Set the authentication header.
+	if tokenType == "oauth" {
+		req.Header.Set("Authorization", "Bearer "+token)
+	} else {
+		req.Header.Set("x-api-key", token)
+	}
+
+	// Pass whitelisted client headers through.
+	for key, values := range c.Request.Header {
+		if !allowedHeaders[strings.ToLower(key)] {
+			continue
+		}
+		for _, v := range values {
+			req.Header.Add(key, v)
+		}
+	}
+
+	// OAuth accounts: apply the fingerprint to the request headers (no-op otherwise).
+	applyFingerprint(req)
+
+	// Ensure the required headers are present.
+	if req.Header.Get("Content-Type") == "" {
+		req.Header.Set("Content-Type", "application/json")
+	}
+	if req.Header.Get("anthropic-version") == "" {
+		req.Header.Set("anthropic-version", "2023-06-01")
+	}
+
+	// OAuth accounts: set the anthropic-beta header.
+	if tokenType == "oauth" {
+		req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
+	}
+
+	// Configure the proxy.
+	// NOTE(review): a new Transport (with its own connection pool) is built on every
+	// call, so proxied connections are not reused across requests; consider caching.
+	var customClient *http.Client
+	if account.ProxyID != nil && account.Proxy != nil {
+		if proxyURL := account.Proxy.URL(); proxyURL != "" {
+			parsedURL, err := url.Parse(proxyURL)
+			if err != nil {
+				// Log the account only: the URL and err may contain proxy credentials.
+				log.Printf("CountTokens: invalid proxy URL, sending without proxy (account: %s)", account.Name)
+			} else {
+				responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
+				if responseHeaderTimeout == 0 {
+					responseHeaderTimeout = 300 * time.Second
+				}
+				customClient = &http.Client{Transport: &http.Transport{
+					Proxy:                 http.ProxyURL(parsedURL),
+					MaxIdleConns:          100,
+					MaxIdleConnsPerHost:   10,
+					IdleConnTimeout:       90 * time.Second,
+					ResponseHeaderTimeout: responseHeaderTimeout,
+				}}
+			}
+		}
+	}
+
+	return &buildUpstreamRequestResult{
+		Request: req,
+		Client:  customClient,
+	}, nil
+}
+
+// countTokensError writes a count_tokens error response to c.
+func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
+	c.JSON(status, gin.H{
+		"type": "error",
+		"error": gin.H{
+			"type":    errType,
+			"message": message,
+		},
+	})
+}