package service

import (
	"bufio"
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"sub2api/internal/config"
	"sub2api/internal/model"
	"sub2api/internal/repository"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
)

const (
	claudeAPIURL        = "https://api.anthropic.com/v1/messages?beta=true"
	stickySessionPrefix = "sticky_session:"
	stickySessionTTL    = time.Hour // TTL of a sticky-session binding
	tokenRefreshBuffer  = 5 * 60    // refresh tokens 5 minutes (in seconds) before they expire

	// gatewayDefaultResponseHeaderTimeout is used when the configured
	// response-header timeout is zero. Upstream may queue requests for a
	// long time under heavy load, hence the generous default.
	gatewayDefaultResponseHeaderTimeout = 300 * time.Second
)

// gatewaySessionIDRegex extracts the session UUID embedded in metadata.user_id.
// It is compiled once at package scope instead of on every request.
var gatewaySessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)

// allowedHeaders is the whitelist of (lower-cased) client request headers
// that are passed through to the upstream API.
var allowedHeaders = map[string]bool{
	"accept":                      true,
	"x-stainless-retry-count":     true,
	"x-stainless-timeout":         true,
	"x-stainless-lang":            true,
	"x-stainless-package-version": true,
	"x-stainless-os":              true,
	"x-stainless-arch":            true,
	"x-stainless-runtime":         true,
	"x-stainless-runtime-version": true,
	"x-stainless-helper-method":   true,
	"anthropic-dangerous-direct-browser-access": true,
	"anthropic-version":                         true,
	"x-app":                                     true,
	"anthropic-beta":                            true,
	"accept-language":                           true,
	"sec-fetch-mode":                            true,
	"accept-encoding":                           true,
	"user-agent":                                true,
	"content-type":                              true,
}

// ClaudeUsage is the usage object returned by the Claude API.
type ClaudeUsage struct {
	InputTokens              int `json:"input_tokens"`
	OutputTokens             int `json:"output_tokens"`
	CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     int `json:"cache_read_input_tokens"`
}

// ForwardResult describes the outcome of a forwarded request.
type ForwardResult struct {
	RequestID    string
	Usage        ClaudeUsage
	Model        string
	Stream       bool
	Duration     time.Duration
	FirstTokenMs *int // time to first token in milliseconds (streaming requests only)
}

// GatewayService handles API gateway operations.
type GatewayService struct {
	repos               *repository.Repositories
	rdb                 *redis.Client
	cfg                 *config.Config
	oauthService        *OAuthService
	billingService      *BillingService
	rateLimitService    *RateLimitService
	billingCacheService *BillingCacheService
	identityService     *IdentityService
	httpClient          *http.Client // client used for accounts without a proxy

	// proxyClientsMu guards proxyClients.
	proxyClientsMu sync.Mutex
	// proxyClients caches one HTTP client per proxy URL so that proxied
	// accounts neither mutate the shared client nor build a fresh transport
	// for every request. It is created lazily.
	proxyClients map[string]*http.Client
}

// gatewayResponseHeaderTimeout returns the configured response-header timeout,
// falling back to gatewayDefaultResponseHeaderTimeout when it is zero.
func gatewayResponseHeaderTimeout(cfg *config.Config) time.Duration {
	timeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
	if timeout == 0 {
		return gatewayDefaultResponseHeaderTimeout
	}
	return timeout
}

// newGatewayTransport builds the transport used for upstream requests.
// proxy may be nil, in which case no proxy is used.
func newGatewayTransport(proxy func(*http.Request) (*url.URL, error), responseHeaderTimeout time.Duration) *http.Transport {
	return &http.Transport{
		Proxy:                 proxy,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		ResponseHeaderTimeout: responseHeaderTimeout, // bounds only the wait for upstream response headers
		// No overall timeout: streaming responses may run for a long time.
	}
}

// NewGatewayService creates a new GatewayService.
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
	return &GatewayService{
		repos:               repos,
		rdb:                 rdb,
		cfg:                 cfg,
		oauthService:        oauthService,
		billingService:      billingService,
		rateLimitService:    rateLimitService,
		billingCacheService: billingCacheService,
		identityService:     identityService,
		httpClient: &http.Client{
			// Client.Timeout is deliberately unset: it would also cap the
			// duration of streaming bodies. Only the response-header wait
			// is bounded, via the transport.
			Transport: newGatewayTransport(nil, gatewayResponseHeaderTimeout(cfg)),
		},
	}
}

// httpClientFor returns the HTTP client to use for account. Accounts without
// a usable proxy get the shared default client; proxied accounts get a client
// dedicated to their proxy URL, created on first use and cached afterwards.
func (s *GatewayService) httpClientFor(account *model.Account) *http.Client {
	if account.ProxyID == nil || account.Proxy == nil {
		return s.httpClient
	}
	proxyURL := account.Proxy.URL()
	if proxyURL == "" {
		return s.httpClient
	}
	parsedURL, err := url.Parse(proxyURL)
	if err != nil {
		// Keep the previous fallback (request goes out without this proxy),
		// but make it visible. The URL/error is not logged because it may
		// contain proxy credentials.
		log.Printf("Warning: invalid proxy URL configured for account %d, using default client", account.ID)
		return s.httpClient
	}

	s.proxyClientsMu.Lock()
	defer s.proxyClientsMu.Unlock()

	if client, ok := s.proxyClients[proxyURL]; ok {
		return client
	}
	if s.proxyClients == nil {
		s.proxyClients = make(map[string]*http.Client)
	}
	client := &http.Client{
		Transport: newGatewayTransport(http.ProxyURL(parsedURL), gatewayResponseHeaderTimeout(s.cfg)),
	}
	s.proxyClients[proxyURL] = client
	return client
}

// GenerateSessionHash derives the sticky-session hash from a request body.
//
// Sources are tried in order: the session UUID inside metadata.user_id,
// content marked with cache_control {type: "ephemeral"}, the system prompt,
// and finally the first message. It returns "" when the body is not valid
// JSON or none of the sources yields any text.
func (s *GatewayService) GenerateSessionHash(body []byte) string {
	var req map[string]interface{}
	if err := json.Unmarshal(body, &req); err != nil {
		return ""
	}

	// 1. Highest priority: session_xxx embedded in metadata.user_id.
	if metadata, ok := req["metadata"].(map[string]interface{}); ok {
		if userID, ok := metadata["user_id"].(string); ok {
			if match := gatewaySessionIDRegex.FindStringSubmatch(userID); len(match) > 1 {
				return match[1]
			}
		}
	}

	// 2. Content carrying cache_control: {type: "ephemeral"}.
	if cacheableContent := s.extractCacheableContent(req); cacheableContent != "" {
		return s.hashContent(cacheableContent)
	}

	// 3. Fallback: the system prompt.
	if system := req["system"]; system != nil {
		if systemText := s.extractTextFromSystem(system); systemText != "" {
			return s.hashContent(systemText)
		}
	}

	// 4. Last resort: the first message.
	if messages, ok := req["messages"].([]interface{}); ok && len(messages) > 0 {
		if firstMsg, ok := messages[0].(map[string]interface{}); ok {
			if msgText := s.extractTextFromContent(firstMsg["content"]); msgText != "" {
				return s.hashContent(msgText)
			}
		}
	}

	return ""
}

// extractCacheableContent returns the text of content marked as ephemeral
// cacheable. If any message contains such a part, the full text of the first
// such message is returned; otherwise the concatenated text of the marked
// system parts is returned (possibly "").
func (s *GatewayService) extractCacheableContent(req map[string]interface{}) string {
	var systemText strings.Builder

	// Collect cacheable parts of the system prompt.
	if system, ok := req["system"].([]interface{}); ok {
		for _, part := range system {
			partMap, ok := part.(map[string]interface{})
			if !ok || !isEphemeralPart(partMap) {
				continue
			}
			if text, ok := partMap["text"].(string); ok {
				systemText.WriteString(text)
			}
		}
	}

	// A cacheable part inside any message takes precedence over the system text.
	if messages, ok := req["messages"].([]interface{}); ok {
		for _, msg := range messages {
			msgMap, ok := msg.(map[string]interface{})
			if !ok {
				continue
			}
			msgContent, ok := msgMap["content"].([]interface{})
			if !ok {
				continue
			}
			for _, part := range msgContent {
				if partMap, ok := part.(map[string]interface{}); ok && isEphemeralPart(partMap) {
					// Use the text of the message that holds the marker.
					return s.extractTextFromContent(msgMap["content"])
				}
			}
		}
	}

	return systemText.String()
}

// isEphemeralPart reports whether a content part has cache_control.type == "ephemeral".
func isEphemeralPart(part map[string]interface{}) bool {
	cc, ok := part["cache_control"].(map[string]interface{})
	return ok && cc["type"] == "ephemeral"
}

// extractTextFromSystem returns the text of a system prompt, which may be a
// plain string or a list of parts; every part's "text" is concatenated.
func (s *GatewayService) extractTextFromSystem(system interface{}) string {
	switch v := system.(type) {
	case string:
		return v
	case []interface{}:
		var texts []string
		for _, part := range v {
			if partMap, ok := part.(map[string]interface{}); ok {
				if text, ok := partMap["text"].(string); ok {
					texts = append(texts, text)
				}
			}
		}
		return strings.Join(texts, "")
	}
	return ""
}

// extractTextFromContent returns the text of a message content value, which
// may be a plain string or a list of parts; only parts with type "text" are
// concatenated.
func (s *GatewayService) extractTextFromContent(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		var texts []string
		for _, part := range v {
			partMap, ok := part.(map[string]interface{})
			if !ok || partMap["type"] != "text" {
				continue
			}
			if text, ok := partMap["text"].(string); ok {
				texts = append(texts, text)
			}
		}
		return strings.Join(texts, "")
	}
	return ""
}

// hashContent returns the first 16 bytes of the SHA-256 of content as a
// 32-character hex string.
func (s *GatewayService) hashContent(content string) string {
	hash := sha256.Sum256([]byte(content))
	return hex.EncodeToString(hash[:16])
}

// replaceModelInBody replaces the "model" field of a JSON request body.
// The body is returned unchanged if it cannot be decoded or re-encoded.
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
	var req map[string]interface{}
	if err := json.Unmarshal(body, &req); err != nil {
		return body
	}
	req["model"] = newModel
	newBody, err := json.Marshal(req)
	if err != nil {
		return body
	}
	return newBody
}

// SelectAccount selects an account using sticky sessions and priority,
// without any model constraint.
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
	return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}

// SelectAccountForModel selects an account that supports requestedModel
// (any account when requestedModel is ""). A still-schedulable account bound
// to sessionHash is preferred; otherwise the schedulable account with the
// lowest Priority value wins, ties going to the least recently used one, and
// the choice is bound to sessionHash.
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
	stickyKey := stickySessionPrefix + sessionHash

	// 1. Look up the sticky session. Any lookup error is treated as a miss.
	if sessionHash != "" {
		accountID, err := s.rdb.Get(ctx, stickyKey).Int64()
		if err == nil && accountID > 0 {
			account, err := s.repos.Account.GetByID(ctx, accountID)
			// IsSchedulable (rather than a plain active check) keeps
			// rate-limited/overloaded accounts from being picked; the
			// model must be supported as well.
			if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
				// Extend the binding; failure is not fatal.
				if err := s.rdb.Expire(ctx, stickyKey, stickySessionTTL).Err(); err != nil {
					log.Printf("Warning: failed to renew sticky session: %v", err)
				}
				return account, nil
			}
		}
	}

	// 2. Load the schedulable accounts.
	var accounts []model.Account
	var err error
	if groupID != nil {
		accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
	} else {
		accounts, err = s.repos.Account.ListSchedulable(ctx)
	}
	if err != nil {
		return nil, fmt.Errorf("query accounts failed: %w", err)
	}

	// 3. Pick by priority, then by least recently used.
	var selected *model.Account
	for i := range accounts {
		acc := &accounts[i]
		if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
			continue
		}
		switch {
		case selected == nil:
			selected = acc
		case acc.Priority < selected.Priority:
			// A smaller priority value means a higher priority.
			selected = acc
		case acc.Priority == selected.Priority:
			// Never-used accounts win; otherwise the older LastUsedAt wins.
			if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
				selected = acc
			}
		}
	}

	if selected == nil {
		if requestedModel != "" {
			return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
		}
		return nil, errors.New("no available accounts")
	}

	// 4. Bind the session to the chosen account; failure is not fatal.
	if sessionHash != "" {
		if err := s.rdb.Set(ctx, stickyKey, selected.ID, stickySessionTTL).Err(); err != nil {
			log.Printf("Warning: failed to store sticky session: %v", err)
		}
	}

	return selected, nil
}

// GetAccessToken returns the credential for account together with its kind:
// "oauth" for OAuth and setup-token accounts, "apikey" for API-key accounts.
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
	switch account.Type {
	case model.AccountTypeOAuth, model.AccountTypeSetupToken:
		// Both oauth and setup-token use the OAuth token flow.
		return s.getOAuthToken(ctx, account)
	case model.AccountTypeApiKey:
		apiKey := account.GetCredential("api_key")
		if apiKey == "" {
			return "", "", errors.New("api_key not found in credentials")
		}
		return apiKey, "apikey", nil
	default:
		return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
	}
}

// getOAuthToken returns the account's access token, refreshing it first when
// it is missing or expires within tokenRefreshBuffer seconds.
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
	accessToken := account.GetCredential("access_token")

	needRefresh := accessToken == ""
	if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
		// An unparsable expires_at is ignored, as before.
		expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
		if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
			needRefresh = true
		}
	}
	if !needRefresh {
		return accessToken, "oauth", nil
	}

	token, tokenType, err := s.forceRefreshToken(ctx, account)
	if err != nil {
		return "", "", fmt.Errorf("refresh token failed: %w", err)
	}
	return token, tokenType, nil
}

// Forward forwards a request to the Claude API on behalf of account, writes
// the upstream response (or a sanitized error) to c, and returns usage
// information for billing. An OAuth 401 triggers one token refresh and retry.
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
	startTime := time.Now()

	// Only model and stream are needed from the request.
	var req struct {
		Model  string `json:"model"`
		Stream bool   `json:"stream"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}

	// Model mapping applies to apikey accounts only.
	originalModel := req.Model
	if account.Type == model.AccountTypeApiKey {
		if mappedModel := account.GetMappedModel(req.Model); mappedModel != req.Model {
			body = s.replaceModelInBody(body, mappedModel)
			req.Model = mappedModel
			log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
		}
	}

	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
	if err != nil {
		return nil, err
	}

	// Per-account client: proxied accounts must not share (or mutate) the
	// default client's transport.
	client := s.httpClientFor(account)

	resp, err := client.Do(upstreamReq)
	if err != nil {
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}
	defer resp.Body.Close()

	// On 401 with an OAuth token: force a refresh and retry once.
	if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
		resp.Body.Close()
		token, tokenType, err = s.forceRefreshToken(ctx, account)
		if err != nil {
			return nil, fmt.Errorf("token refresh failed: %w", err)
		}
		upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
		if err != nil {
			return nil, err
		}
		resp, err = client.Do(upstreamReq)
		if err != nil {
			return nil, fmt.Errorf("retry request failed: %w", err)
		}
		defer resp.Body.Close()
	}

	if resp.StatusCode >= 400 {
		return s.handleErrorResponse(ctx, resp, c, account)
	}

	var usage *ClaudeUsage
	var firstTokenMs *int
	if req.Stream {
		streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
		if err != nil {
			return nil, err
		}
	}

	return &ForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel, // billing and logs use the model the client asked for
		Stream:       req.Stream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}

// buildUpstreamRequest builds the upstream HTTP request for account: target
// URL, auth header, whitelisted client headers, optional fingerprint and
// user_id rewriting for OAuth accounts, and required default headers.
// It does not select the transport; see httpClientFor.
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
	targetURL := claudeAPIURL
	if account.Type == model.AccountTypeApiKey {
		// Tolerate a trailing slash on the configured base URL.
		targetURL = strings.TrimRight(account.GetBaseURL(), "/") + "/v1/messages"
	}

	// OAuth accounts: apply a unified fingerprint.
	var fingerprint *Fingerprint
	if account.IsOAuth() && s.identityService != nil {
		fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
		if err != nil {
			// Degrade to passing the client's headers through unchanged.
			log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
		} else {
			fingerprint = fp
			// Rewrite metadata.user_id; needs both the fingerprint's
			// ClientID and the account's account_uuid.
			accountUUID := account.GetExtraString("account_uuid")
			if accountUUID != "" && fp.ClientID != "" {
				if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
					body = newBody
				}
			}
		}
	}

	req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	if tokenType == "oauth" {
		req.Header.Set("Authorization", "Bearer "+token)
	} else {
		req.Header.Set("x-api-key", token)
	}

	// Pass through whitelisted client headers.
	for key, values := range c.Request.Header {
		if !allowedHeaders[strings.ToLower(key)] {
			continue
		}
		for _, v := range values {
			req.Header.Add(key, v)
		}
	}

	// The cached fingerprint is applied after the pass-through headers.
	if fingerprint != nil {
		s.identityService.ApplyFingerprint(req, fingerprint)
	}

	// Make sure required headers are present.
	if req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "application/json")
	}
	if req.Header.Get("anthropic-version") == "" {
		req.Header.Set("anthropic-version", "2023-06-01")
	}

	// OAuth accounts need the oauth beta flag in anthropic-beta.
	if tokenType == "oauth" {
		req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
	}

	return req, nil
}

// getBetaHeader computes the anthropic-beta header for OAuth accounts,
// guaranteeing that it contains oauth-2025-04-20. A client-supplied header is
// extended (after claude-code-20250219 if present, otherwise at the front);
// without one, a default set is chosen based on the request's model.
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
	const oauthBeta = "oauth-2025-04-20"
	const claudeCodeBeta = "claude-code-20250219"

	if clientBetaHeader != "" {
		// Already contains the oauth beta: use as is.
		if strings.Contains(clientBetaHeader, oauthBeta) {
			return clientBetaHeader
		}

		parts := strings.Split(clientBetaHeader, ",")
		for i, p := range parts {
			parts[i] = strings.TrimSpace(p)
		}

		for i, p := range parts {
			if p != claudeCodeBeta {
				continue
			}
			// Insert right after claude-code.
			newParts := make([]string, 0, len(parts)+1)
			newParts = append(newParts, parts[:i+1]...)
			newParts = append(newParts, oauthBeta)
			newParts = append(newParts, parts[i+1:]...)
			return strings.Join(newParts, ",")
		}

		// No claude-code entry: put the oauth beta first.
		return oauthBeta + "," + clientBetaHeader
	}

	// No client header: derive defaults from the model.
	var modelID string
	var reqMap map[string]interface{}
	if json.Unmarshal(body, &reqMap) == nil {
		if m, ok := reqMap["model"].(string); ok {
			modelID = m
		}
	}

	// Haiku models do not get the claude-code beta.
	if strings.Contains(strings.ToLower(modelID), "haiku") {
		return "oauth-2025-04-20,interleaved-thinking-2025-05-14"
	}

	return "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
}

// forceRefreshToken refreshes the account's OAuth token unconditionally,
// stores the new credentials on account and persists them. A persistence
// failure is logged but does not fail the call, since the new token is usable.
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
	tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
	if err != nil {
		return "", "", err
	}

	account.Credentials["access_token"] = tokenInfo.AccessToken
	account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
	if tokenInfo.RefreshToken != "" {
		account.Credentials["refresh_token"] = tokenInfo.RefreshToken
	}

	if err := s.repos.Account.Update(ctx, account); err != nil {
		log.Printf("Failed to update account credentials: %v", err)
	}

	return tokenInfo.AccessToken, "oauth", nil
}

// handleErrorResponse handles an upstream response with status >= 400. It
// updates account state through the rate-limit service and writes a sanitized
// error to the client; upstream details are never passed through. It always
// returns a nil result and a non-nil error.
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// Continue with whatever was read; the body is only used for
		// account-state handling.
		log.Printf("Warning: failed to read upstream error body: %v", err)
	}

	// Custom error-code configuration: when the account says this status
	// should not be handled, answer with a generic 500 and leave the account
	// state untouched.
	if !account.ShouldHandleErrorCode(resp.StatusCode) {
		c.JSON(http.StatusInternalServerError, gin.H{
			"type": "error",
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream gateway error",
			},
		})
		return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
	}

	// Let the rate-limit service mark the account according to the error.
	s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)

	// Map the upstream status to a client-facing status and message.
	statusCode := http.StatusBadGateway
	errType := "upstream_error"
	var errMsg string

	switch resp.StatusCode {
	case 401:
		errMsg = "Upstream authentication failed, please contact administrator"
	case 403:
		errMsg = "Upstream access forbidden, please contact administrator"
	case 429:
		statusCode = http.StatusTooManyRequests
		errType = "rate_limit_error"
		errMsg = "Upstream rate limit exceeded, please retry later"
	case 529:
		statusCode = http.StatusServiceUnavailable
		errType = "overloaded_error"
		errMsg = "Upstream service overloaded, please retry later"
	case 500, 502, 503, 504:
		errMsg = "Upstream service temporarily unavailable"
	default:
		errMsg = "Upstream request failed"
	}

	c.JSON(statusCode, gin.H{
		"type": "error",
		"error": gin.H{
			"type":    errType,
			"message": errMsg,
		},
	})

	return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}

// streamingResult is the outcome of a streamed response.
type streamingResult struct {
	usage        *ClaudeUsage
	firstTokenMs *int
}

// handleStreamingResponse relays an SSE response line by line to the client,
// flushing after each line, while collecting usage and the time to the first
// data event. When a model mapping is active, the model name in the
// message_start event is mapped back to the one the client requested.
// On a read error the partial result is returned together with the error.
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
	// Update the session-window state from the response headers.
	s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)

	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}

	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}

	usage := &ClaudeUsage{}
	var firstTokenMs *int
	scanner := bufio.NewScanner(resp.Body)
	// Larger buffer so long SSE lines (up to 1 MiB) can be scanned.
	scanner.Buffer(make([]byte, 64*1024), 1024*1024)

	needModelReplace := originalModel != mappedModel

	for scanner.Scan() {
		line := scanner.Text()
		isData := strings.HasPrefix(line, "data: ")

		if needModelReplace && isData {
			line = s.replaceModelInSSELine(line, mappedModel, originalModel)
		}

		// Relay the line. Write errors are ignored on purpose so the loop
		// keeps collecting usage from upstream.
		fmt.Fprintf(w, "%s\n", line)
		flusher.Flush()

		if !isData {
			continue
		}
		data := line[6:]
		// First non-empty data event that is not [DONE] marks the first token.
		if firstTokenMs == nil && data != "" && data != "[DONE]" {
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}
		s.parseSSEUsage(data, usage)
	}

	if err := scanner.Err(); err != nil {
		return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
	}

	return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}

// replaceModelInSSELine rewrites message.model from fromModel to toModel in
// a "data: " SSE line carrying a message_start event. Any other line, or one
// that cannot be decoded, is returned unchanged. The caller must pass a line
// that starts with "data: ".
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
	data := line[6:] // strip the "data: " prefix
	if data == "" || data == "[DONE]" {
		return line
	}

	var event map[string]interface{}
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return line
	}

	// Only message_start events carry message.model.
	if event["type"] != "message_start" {
		return line
	}

	msg, ok := event["message"].(map[string]interface{})
	if !ok {
		return line
	}

	current, ok := msg["model"].(string)
	if !ok || current != fromModel {
		return line
	}

	msg["model"] = toModel
	newData, err := json.Marshal(event)
	if err != nil {
		return line
	}

	return "data: " + string(newData)
}

// parseSSEUsage updates usage from one SSE data payload: message_start
// supplies the input and cache token counts, message_delta the output token
// count. Other payloads are ignored.
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
	raw := []byte(data)

	var msgStart struct {
		Type    string `json:"type"`
		Message struct {
			Usage ClaudeUsage `json:"usage"`
		} `json:"message"`
	}
	if json.Unmarshal(raw, &msgStart) == nil && msgStart.Type == "message_start" {
		usage.InputTokens = msgStart.Message.Usage.InputTokens
		usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
		usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
		// An event has exactly one type; no need to decode it again.
		return
	}

	var msgDelta struct {
		Type  string `json:"type"`
		Usage struct {
			OutputTokens int `json:"output_tokens"`
		} `json:"usage"`
	}
	if json.Unmarshal(raw, &msgDelta) == nil && msgDelta.Type == "message_delta" {
		usage.OutputTokens = msgDelta.Usage.OutputTokens
	}
}

// handleNonStreamingResponse reads the whole upstream response, extracts its
// usage, maps the model name back to the one the client requested when a
// mapping is active, and writes headers and body to the client.
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
	// Update the session-window state from the response headers.
	s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var response struct {
		Usage ClaudeUsage `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}

	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}

	// Pass upstream headers through, keeping every value of multi-valued
	// headers. Content-Length is skipped: the body may have been re-encoded
	// above, and a stale length would not match what is written; net/http
	// determines the framing itself.
	dst := c.Writer.Header()
	for key, values := range resp.Header {
		if strings.EqualFold(key, "Content-Length") {
			continue
		}
		dst.Del(key)
		for _, value := range values {
			dst.Add(key, value)
		}
	}

	c.Data(resp.StatusCode, "application/json", body)

	return &response.Usage, nil
}

// replaceModelInResponseBody rewrites the top-level "model" field of a JSON
// response from fromModel to toModel. The body is returned unchanged when it
// cannot be decoded or re-encoded, or when the model does not equal fromModel.
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
	var resp map[string]interface{}
	if err := json.Unmarshal(body, &resp); err != nil {
		return body
	}

	current, ok := resp["model"].(string)
	if !ok || current != fromModel {
		return body
	}

	resp["model"] = toModel
	newBody, err := json.Marshal(resp)
	if err != nil {
		return body
	}

	return newBody
}

// RecordUsageInput is the input of RecordUsage.
type RecordUsageInput struct {
	Result       *ForwardResult
	ApiKey       *model.ApiKey
	User         *model.User
	Account      *model.Account
	Subscription *model.UserSubscription // optional: subscription information
}

// RecordUsage writes a usage log for a forwarded request and charges for it:
// subscription billing increments the subscription's usage by TotalCost,
// balance billing deducts ActualCost from the user's balance. Failures of the
// individual steps are logged and do not abort the remaining steps; the
// returned error is always nil.
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
	result := input.Result
	apiKey := input.ApiKey
	user := input.User
	account := input.Account
	subscription := input.Subscription

	tokens := UsageTokens{
		InputTokens:         result.Usage.InputTokens,
		OutputTokens:        result.Usage.OutputTokens,
		CacheCreationTokens: result.Usage.CacheCreationInputTokens,
		CacheReadTokens:     result.Usage.CacheReadInputTokens,
	}

	// The group's rate multiplier overrides the default one.
	multiplier := s.cfg.Default.RateMultiplier
	if apiKey.GroupID != nil && apiKey.Group != nil {
		multiplier = apiKey.Group.RateMultiplier
	}

	cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
	if err != nil {
		log.Printf("Calculate cost failed: %v", err)
		// Continue with a zero cost so the usage is still logged.
		cost = &CostBreakdown{ActualCost: 0}
	}

	// Billing mode: subscription vs balance.
	isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
	billingType := model.BillingTypeBalance
	if isSubscriptionBilling {
		billingType = model.BillingTypeSubscription
	}

	durationMs := int(result.Duration.Milliseconds())
	usageLog := &model.UsageLog{
		UserID:              user.ID,
		ApiKeyID:            apiKey.ID,
		AccountID:           account.ID,
		RequestID:           result.RequestID,
		Model:               result.Model,
		InputTokens:         result.Usage.InputTokens,
		OutputTokens:        result.Usage.OutputTokens,
		CacheCreationTokens: result.Usage.CacheCreationInputTokens,
		CacheReadTokens:     result.Usage.CacheReadInputTokens,
		InputCost:           cost.InputCost,
		OutputCost:          cost.OutputCost,
		CacheCreationCost:   cost.CacheCreationCost,
		CacheReadCost:       cost.CacheReadCost,
		TotalCost:           cost.TotalCost,
		ActualCost:          cost.ActualCost,
		RateMultiplier:      multiplier,
		BillingType:         billingType,
		Stream:              result.Stream,
		DurationMs:          &durationMs,
		FirstTokenMs:        result.FirstTokenMs,
		CreatedAt:           time.Now(),
	}

	// Group and subscription associations.
	if apiKey.GroupID != nil {
		usageLog.GroupID = apiKey.GroupID
	}
	if subscription != nil {
		usageLog.SubscriptionID = &subscription.ID
	}

	if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
		log.Printf("Create usage log failed: %v", err)
	}

	if isSubscriptionBilling {
		// Subscription mode: usage is tracked with TotalCost, i.e. without
		// the rate multiplier.
		if cost.TotalCost > 0 {
			if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
				log.Printf("Increment subscription usage failed: %v", err)
			}
			// Update the subscription cache asynchronously. GroupID is
			// checked and copied here: dereferencing a nil pointer inside
			// the goroutine would crash the whole process.
			if apiKey.GroupID != nil {
				userID, groupID, totalCost := user.ID, *apiKey.GroupID, cost.TotalCost
				go func() {
					cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
					defer cancel()
					if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, userID, groupID, totalCost); err != nil {
						log.Printf("Update subscription cache failed: %v", err)
					}
				}()
			}
		}
	} else {
		// Balance mode: deduct ActualCost, i.e. after the rate multiplier.
		if cost.ActualCost > 0 {
			if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
				log.Printf("Deduct balance failed: %v", err)
			}
			// Update the balance cache asynchronously.
			userID, actualCost := user.ID, cost.ActualCost
			go func() {
				cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				if err := s.billingCacheService.DeductBalanceCache(cacheCtx, userID, actualCost); err != nil {
					log.Printf("Update balance cache failed: %v", err)
				}
			}()
		}
	}

	// Record when the account was last used.
	if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
		log.Printf("Update last used failed: %v", err)
	}

	return nil
}