diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index e1b1b9a8..b9285c04 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 验证 model 必填 if reqModel == "" { @@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false))) // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go new file mode 100644 index 00000000..da376036 --- /dev/null +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -0,0 +1,289 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups. +// POST /v1/chat/completions +// This converts Chat Completions requests to Anthropic format (via Responses format chain), +// forwards to Anthropic upstream, and converts responses back to Chat Completions format. 
+func (h *GatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, 
reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. 
Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.chatCompletionsErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleCCFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.cc.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.cc.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format. +func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// handleCCFailoverExhausted writes a failover-exhausted error in CC format. 
+func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go new file mode 100644 index 00000000..d146d724 --- /dev/null +++ b/backend/internal/handler/gateway_handler_responses.go @@ -0,0 +1,295 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// Responses handles OpenAI Responses API endpoint for Anthropic platform groups. +// POST /v1/responses +// This converts Responses API requests to Anthropic format, forwards to Anthropic +// upstream, and converts responses back to Responses format. 
+func (h *GatewayHandler) Responses(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream using gjson (like OpenAI handler) + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + 
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction: + // /v1/responses is never a Claude Code endpoint. + // When claude_code_only is enabled, this endpoint is rejected. + // The existing service-layer checkClaudeCodeRestriction handles degradation + // to fallback groups when the Forward path calls SelectAccountForModelWithExclusions. + // Here we just reject at handler level since /v1/responses clients can't be Claude Code. + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.responsesErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. 
Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. 
Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.responsesErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "responses") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + // Can't failover if streaming content already sent + if c.Writer.Size() != writerSizeBeforeForward { + h.handleResponsesFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.responses.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.responses.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// responsesErrorResponse writes an error in OpenAI Responses API format. +func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": code, + "message": message, + }, + }) +} + +// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format. 
+func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return // Can't write error after stream started + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index fb231898..5dc03b6d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } setOpsRequestContext(c, modelName, stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index dd158d8b..0c94aa21 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index b7f18d21..3ce6e5d6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) 
Responses(c *gin.Context) { } setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { @@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { @@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.String("previous_response_id_kind", previousResponseIDKind), ) setOpsRequestContext(c, reqModel, true, firstMessage) + setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) var currentUserRelease func() var currentAccountRelease func() diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index ceb06f0e..90e90dd0 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -27,6 +27,9 @@ const ( opsRequestBodyKey = "ops_request_body" opsAccountIDKey = "ops_account_id" + opsUpstreamModelKey = "ops_upstream_model" + opsRequestTypeKey = "ops_request_type" + // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 opsErrContextCanceled = "context canceled" opsErrNoAvailableAccounts = "no available accounts" @@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody } } +// setOpsEndpointContext stores upstream model and request type for ops error logging. +// Called by handlers after model mapping and request type determination. 
+func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) { + if c == nil { + return + } + if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" { + c.Set(opsUpstreamModelKey, upstreamModel) + } + c.Set(opsRequestTypeKey, requestType) +} + func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { if c == nil || entry == nil { return @@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), + RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: "upstream", @@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), + RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 679dd4ce..6ae45110 100644 --- 
a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) { }) } } + +func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream + + v, ok := c.Get(opsUpstreamModelKey) + require.True(t, ok) + vStr, ok := v.(string) + require.True(t, ok) + require.Equal(t, "claude-3-5-sonnet-20241022", vStr) + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(2), rtVal) +} + +func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "", int16(1)) + + _, ok := c.Get(opsUpstreamModelKey) + require.False(t, ok, "empty upstream model should not be stored") + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(1), rtVal) +} + +func TestSetOpsEndpointContext_NilContext(t *testing.T) { + require.NotPanics(t, func() { + setOpsEndpointContext(nil, "model", int16(1)) + }) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index cc1b1c0b..5e505409 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { } setOpsRequestContext(c, reqModel, clientStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false))) platform := "" if 
forced, ok := middleware2.GetForcePlatformFromContext(c); ok { diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses_response.go b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go new file mode 100644 index 00000000..9290e399 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go @@ -0,0 +1,521 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: AnthropicResponse → ResponsesResponse +// --------------------------------------------------------------------------- + +// AnthropicToResponsesResponse converts an Anthropic Messages response into a +// Responses API response. This is the reverse of ResponsesToAnthropic and +// enables Anthropic upstream responses to be returned in OpenAI Responses format. +func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse { + id := resp.ID + if id == "" { + id = generateResponsesID() + } + + out := &ResponsesResponse{ + ID: id, + Object: "response", + Model: resp.Model, + } + + var outputs []ResponsesOutput + var msgParts []ResponsesContentPart + + for _, block := range resp.Content { + switch block.Type { + case "thinking": + if block.Thinking != "" { + outputs = append(outputs, ResponsesOutput{ + Type: "reasoning", + ID: generateItemID(), + Summary: []ResponsesSummary{{ + Type: "summary_text", + Text: block.Thinking, + }}, + }) + } + case "text": + if block.Text != "" { + msgParts = append(msgParts, ResponsesContentPart{ + Type: "output_text", + Text: block.Text, + }) + } + case "tool_use": + args := "{}" + if len(block.Input) > 0 { + args = string(block.Input) + } + outputs = append(outputs, ResponsesOutput{ + Type: "function_call", + ID: generateItemID(), + CallID: toResponsesCallID(block.ID), + Name: block.Name, + Arguments: args, + Status: "completed", + }) + } + } + + // Assemble message 
output item from text parts + if len(msgParts) > 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: msgParts, + Status: "completed", + }) + } + + if len(outputs) == 0 { + outputs = append(outputs, ResponsesOutput{ + Type: "message", + ID: generateItemID(), + Role: "assistant", + Content: []ResponsesContentPart{{Type: "output_text", Text: ""}}, + Status: "completed", + }) + } + out.Output = outputs + + // Map stop_reason → status + out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content) + if out.Status == "incomplete" { + out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"} + } + + // Usage + out.Usage = &ResponsesUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if resp.Usage.CacheReadInputTokens > 0 { + out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: resp.Usage.CacheReadInputTokens, + } + } + + return out +} + +// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status. +func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string { + switch stopReason { + case "max_tokens": + return "incomplete" + case "end_turn", "tool_use", "stop_sequence": + return "completed" + default: + return "completed" + } +} + +// --------------------------------------------------------------------------- +// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter) +// --------------------------------------------------------------------------- + +// AnthropicEventToResponsesState tracks state for converting a sequence of +// Anthropic SSE events into Responses SSE events. 
+type AnthropicEventToResponsesState struct { + ResponseID string + Model string + Created int64 + SequenceNumber int + + // CreatedSent tracks whether response.created has been emitted. + CreatedSent bool + // CompletedSent tracks whether the terminal event has been emitted. + CompletedSent bool + + // Current output tracking + OutputIndex int + CurrentItemID string + CurrentItemType string // "message" | "function_call" | "reasoning" + + // For message output: accumulate text parts + ContentIndex int + + // For function_call: track per-output info + CurrentCallID string + CurrentName string + + // Usage from message_delta + InputTokens int + OutputTokens int + CacheReadInputTokens int +} + +// NewAnthropicEventToResponsesState returns an initialised stream state. +func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState { + return &AnthropicEventToResponsesState{ + Created: time.Now().Unix(), + } +} + +// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into +// zero or more Responses SSE events, updating state as it goes. +func AnthropicEventToResponsesEvents( + evt *AnthropicStreamEvent, + state *AnthropicEventToResponsesState, +) []ResponsesStreamEvent { + switch evt.Type { + case "message_start": + return anthToResHandleMessageStart(evt, state) + case "content_block_start": + return anthToResHandleContentBlockStart(evt, state) + case "content_block_delta": + return anthToResHandleContentBlockDelta(evt, state) + case "content_block_stop": + return anthToResHandleContentBlockStop(evt, state) + case "message_delta": + return anthToResHandleMessageDelta(evt, state) + case "message_stop": + return anthToResHandleMessageStop(state) + default: + return nil + } +} + +// FinalizeAnthropicResponsesStream emits synthetic termination events if the +// stream ended without a proper message_stop. 
+func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if !state.CreatedSent || state.CompletedSent { + return nil + } + + var events []ResponsesStreamEvent + + // Close any open item + events = append(events, closeCurrentResponsesItem(state)...) + + // Emit response.completed + events = append(events, makeResponsesCompletedEvent(state, "completed", nil)) + state.CompletedSent = true + return events +} + +// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line. +func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) { + data, err := json.Marshal(evt) + if err != nil { + return "", err + } + return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil +} + +// --- internal handlers --- + +func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.Message != nil { + state.ResponseID = evt.Message.ID + if state.Model == "" { + state.Model = evt.Message.Model + } + if evt.Message.Usage.InputTokens > 0 { + state.InputTokens = evt.Message.Usage.InputTokens + } + } + + if state.CreatedSent { + return nil + } + state.CreatedSent = true + + // Emit response.created + return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)} +} + +func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.ContentBlock == nil { + return nil + } + + var events []ResponsesStreamEvent + + switch evt.ContentBlock.Type { + case "thinking": + state.CurrentItemID = generateItemID() + state.CurrentItemType = "reasoning" + state.ContentIndex = 0 + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "reasoning", + ID: state.CurrentItemID, + }, + })) + + case "text": + // If we don't have an open message item, open one + if 
state.CurrentItemType != "message" { + state.CurrentItemID = generateItemID() + state.CurrentItemType = "message" + state.ContentIndex = 0 + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "message", + ID: state.CurrentItemID, + Role: "assistant", + Status: "in_progress", + }, + })) + } + + case "tool_use": + // Close previous item if any + events = append(events, closeCurrentResponsesItem(state)...) + + state.CurrentItemID = generateItemID() + state.CurrentItemType = "function_call" + state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID) + state.CurrentName = evt.ContentBlock.Name + + events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Item: &ResponsesOutput{ + Type: "function_call", + ID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + Status: "in_progress", + }, + })) + } + + return events +} + +func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if evt.Delta == nil { + return nil + } + + switch evt.Delta.Type { + case "text_delta": + if evt.Delta.Text == "" { + return nil + } + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ContentIndex: state.ContentIndex, + Delta: evt.Delta.Text, + ItemID: state.CurrentItemID, + })} + + case "thinking_delta": + if evt.Delta.Thinking == "" { + return nil + } + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + SummaryIndex: 0, + Delta: evt.Delta.Thinking, + ItemID: state.CurrentItemID, + })} + + case "input_json_delta": + if evt.Delta.PartialJSON == "" { + return nil + } + return 
[]ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + Delta: evt.Delta.PartialJSON, + ItemID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + })} + + case "signature_delta": + // Anthropic signature deltas have no Responses equivalent; skip + return nil + } + + return nil +} + +func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + switch state.CurrentItemType { + case "reasoning": + // Emit reasoning summary done + output item done + events := []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + SummaryIndex: 0, + ItemID: state.CurrentItemID, + }), + } + events = append(events, closeCurrentResponsesItem(state)...) + return events + + case "function_call": + // Emit function_call_arguments.done + output item done + events := []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ItemID: state.CurrentItemID, + CallID: state.CurrentCallID, + Name: state.CurrentName, + }), + } + events = append(events, closeCurrentResponsesItem(state)...) 
+ return events + + case "message": + // Emit output_text.done (text block is done, but message item stays open for potential more blocks) + return []ResponsesStreamEvent{ + makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex, + ContentIndex: state.ContentIndex, + ItemID: state.CurrentItemID, + }), + } + } + + return nil +} + +func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + // Update usage + if evt.Usage != nil { + state.OutputTokens = evt.Usage.OutputTokens + if evt.Usage.CacheReadInputTokens > 0 { + state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens + } + } + + return nil +} + +func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if state.CompletedSent { + return nil + } + + var events []ResponsesStreamEvent + + // Close any open item + events = append(events, closeCurrentResponsesItem(state)...) + + // Determine status + status := "completed" + var incompleteDetails *ResponsesIncompleteDetails + + // Emit response.completed + events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails)) + state.CompletedSent = true + return events +} + +// --- helper functions --- + +func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if state.CurrentItemType == "" { + return nil + } + + itemType := state.CurrentItemType + itemID := state.CurrentItemID + + // Reset + state.CurrentItemType = "" + state.CurrentItemID = "" + state.CurrentCallID = "" + state.CurrentName = "" + state.OutputIndex++ + state.ContentIndex = 0 + + return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{ + OutputIndex: state.OutputIndex - 1, // Use the index before increment + Item: &ResponsesOutput{ + Type: itemType, + ID: itemID, + Status: "completed", + }, + })} +} + +func 
makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + return ResponsesStreamEvent{ + Type: "response.created", + SequenceNumber: seq, + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: "in_progress", + Output: []ResponsesOutput{}, + }, + } +} + +func makeResponsesCompletedEvent( + state *AnthropicEventToResponsesState, + status string, + incompleteDetails *ResponsesIncompleteDetails, +) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + + usage := &ResponsesUsage{ + InputTokens: state.InputTokens, + OutputTokens: state.OutputTokens, + TotalTokens: state.InputTokens + state.OutputTokens, + } + if state.CacheReadInputTokens > 0 { + usage.InputTokensDetails = &ResponsesInputTokensDetails{ + CachedTokens: state.CacheReadInputTokens, + } + } + + return ResponsesStreamEvent{ + Type: "response.completed", + SequenceNumber: seq, + Response: &ResponsesResponse{ + ID: state.ResponseID, + Object: "response", + Model: state.Model, + Status: status, + Output: []ResponsesOutput{}, // Simplified; full output tracking would add complexity + Usage: usage, + IncompleteDetails: incompleteDetails, + }, + } +} + +func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent { + seq := state.SequenceNumber + state.SequenceNumber++ + + evt := *template + evt.Type = eventType + evt.SequenceNumber = seq + return evt +} + +func generateResponsesID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "resp_" + hex.EncodeToString(b) +} + +func generateItemID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "item_" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go new file mode 100644 index 
00000000..f0a5b07e --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go @@ -0,0 +1,464 @@ +package apicompat + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ResponsesToAnthropicRequest converts a Responses API request into an +// Anthropic Messages request. This is the reverse of AnthropicToResponses and +// enables Anthropic platform groups to accept OpenAI Responses API requests +// by converting them to the native /v1/messages format before forwarding upstream. +func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) { + system, messages, err := convertResponsesInputToAnthropic(req.Input) + if err != nil { + return nil, err + } + + out := &AnthropicRequest{ + Model: req.Model, + Messages: messages, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + } + + if len(system) > 0 { + out.System = system + } + + // max_output_tokens → max_tokens + if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 { + out.MaxTokens = *req.MaxOutputTokens + } + if out.MaxTokens == 0 { + // Anthropic requires max_tokens; default to a sensible value. 
+ out.MaxTokens = 8192 + } + + // Convert tools + if len(req.Tools) > 0 { + out.Tools = convertResponsesToAnthropicTools(req.Tools) + } + + // Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses) + if len(req.ToolChoice) > 0 { + tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = tc + } + + // reasoning.effort → output_config.effort + thinking + if req.Reasoning != nil && req.Reasoning.Effort != "" { + effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort) + out.OutputConfig = &AnthropicOutputConfig{Effort: effort} + // Enable thinking for non-low efforts + if effort != "low" { + out.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: defaultThinkingBudget(effort), + } + } + } + + return out, nil +} + +// defaultThinkingBudget returns a sensible thinking budget based on effort level. +func defaultThinkingBudget(effort string) int { + switch effort { + case "low": + return 1024 + case "medium": + return 4096 + case "high": + return 10240 + case "max": + return 32768 + default: + return 10240 + } +} + +// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to +// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses. +// +// low → low +// medium → medium +// high → high +// xhigh → max +func mapResponsesEffortToAnthropic(effort string) string { + if effort == "xhigh" { + return "max" + } + return effort // low→low, medium→medium, high→high, unknown→passthrough +} + +// convertResponsesInputToAnthropic extracts system prompt and messages from +// a Responses API input array. Returns the system as raw JSON (for Anthropic's +// polymorphic system field) and a list of Anthropic messages. +func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) { + // Try as plain string input. 
+ var inputStr string + if err := json.Unmarshal(inputRaw, &inputStr); err == nil { + content, _ := json.Marshal(inputStr) + return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil + } + + var items []ResponsesInputItem + if err := json.Unmarshal(inputRaw, &items); err != nil { + return nil, nil, fmt.Errorf("parse responses input: %w", err) + } + + var system json.RawMessage + var messages []AnthropicMessage + + for _, item := range items { + switch { + case item.Role == "system": + // System prompt → Anthropic system field + text := extractTextFromContent(item.Content) + if text != "" { + system, _ = json.Marshal(text) + } + + case item.Type == "function_call": + // function_call → assistant message with tool_use block + input := json.RawMessage("{}") + if item.Arguments != "" { + input = json.RawMessage(item.Arguments) + } + block := AnthropicContentBlock{ + Type: "tool_use", + ID: fromResponsesCallIDToAnthropic(item.CallID), + Name: item.Name, + Input: input, + } + blockJSON, _ := json.Marshal([]AnthropicContentBlock{block}) + messages = append(messages, AnthropicMessage{ + Role: "assistant", + Content: blockJSON, + }) + + case item.Type == "function_call_output": + // function_call_output → user message with tool_result block + outputContent := item.Output + if outputContent == "" { + outputContent = "(empty)" + } + contentJSON, _ := json.Marshal(outputContent) + block := AnthropicContentBlock{ + Type: "tool_result", + ToolUseID: fromResponsesCallIDToAnthropic(item.CallID), + Content: contentJSON, + } + blockJSON, _ := json.Marshal([]AnthropicContentBlock{block}) + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: blockJSON, + }) + + case item.Role == "user": + content, err := convertResponsesUserToAnthropicContent(item.Content) + if err != nil { + return nil, nil, err + } + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: content, + }) + + case item.Role == "assistant": + content, err := 
convertResponsesAssistantToAnthropicContent(item.Content) + if err != nil { + return nil, nil, err + } + messages = append(messages, AnthropicMessage{ + Role: "assistant", + Content: content, + }) + + default: + // Unknown role/type — attempt as user message + if item.Content != nil { + messages = append(messages, AnthropicMessage{ + Role: "user", + Content: item.Content, + }) + } + } + } + + // Merge consecutive same-role messages (Anthropic requires alternating roles) + messages = mergeConsecutiveMessages(messages) + + return system, messages, nil +} + +// extractTextFromContent extracts text from a content field that may be a +// plain string or an array of content parts. +func extractTextFromContent(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err == nil { + var texts []string + for _, p := range parts { + if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" { + texts = append(texts, p.Text) + } + } + return strings.Join(texts, "\n\n") + } + return "" +} + +// convertResponsesUserToAnthropicContent converts a Responses user message +// content field into Anthropic content blocks JSON. +func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal("") // empty string content + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Array of content parts → Anthropic content blocks. 
+ var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + // Pass through as-is if we can't parse + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "input_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + case "input_image": + src := dataURIToAnthropicImageSource(p.ImageURL) + if src != nil { + blocks = append(blocks, AnthropicContentBlock{ + Type: "image", + Source: src, + }) + } + } + } + + if len(blocks) == 0 { + return json.Marshal("") + } + return json.Marshal(blocks) +} + +// convertResponsesAssistantToAnthropicContent converts a Responses assistant +// message content field into Anthropic content blocks JSON. +func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}}) + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}}) + } + + // Array of content parts → Anthropic content blocks. + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "output_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + } + } + + if len(blocks) == 0 { + blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""}) + } + return json.Marshal(blocks) +} + +// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to +// Anthropic format. Reverses toResponsesCallID. 
+func fromResponsesCallIDToAnthropic(id string) string {
+	// If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it
+	if after, ok := strings.CutPrefix(id, "fc_"); ok {
+		if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") {
+			return after
+		}
+	}
+	// Generate a synthetic Anthropic tool ID
+	if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") {
+		return "toolu_" + id
+	}
+	return id
+}
+
+// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource.
+// Returns nil when the input is not a base64 data URI of the expected shape.
+func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource {
+	if !strings.HasPrefix(dataURI, "data:") {
+		return nil
+	}
+	// Format: data:<media-type>;base64,<data>
+	// NOTE(review): a media type carrying parameters (e.g. ";charset=utf-8")
+	// is rejected here because only the first ";" is considered — confirm
+	// callers never send parameterized media types.
+	rest := strings.TrimPrefix(dataURI, "data:")
+	semicolonIdx := strings.Index(rest, ";")
+	if semicolonIdx < 0 {
+		return nil
+	}
+	mediaType := rest[:semicolonIdx]
+	rest = rest[semicolonIdx+1:]
+	if !strings.HasPrefix(rest, "base64,") {
+		return nil
+	}
+	data := strings.TrimPrefix(rest, "base64,")
+	return &AnthropicImageSource{
+		Type:      "base64",
+		MediaType: mediaType,
+		Data:      data,
+	}
+}
+
+// mergeConsecutiveMessages merges consecutive messages with the same role
+// because Anthropic requires alternating user/assistant turns.
+func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage {
+	if len(messages) <= 1 {
+		return messages
+	}
+
+	var merged []AnthropicMessage
+	for _, msg := range messages {
+		if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role {
+			merged = append(merged, msg)
+			continue
+		}
+
+		// Same role — merge content arrays
+		last := &merged[len(merged)-1]
+		lastBlocks := parseContentBlocks(last.Content)
+		newBlocks := parseContentBlocks(msg.Content)
+		combined := append(lastBlocks, newBlocks...)
+		last.Content, _ = json.Marshal(combined)
+	}
+	return merged
+}
+
+// parseContentBlocks attempts to parse content as []AnthropicContentBlock.
+// If it's a string, wraps it in a text block.
+func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock { + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err == nil { + return blocks + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return []AnthropicContentBlock{{Type: "text", Text: s}} + } + return nil +} + +// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format. +// Reverse of convertAnthropicToolsToResponses. +func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool { + var out []AnthropicTool + for _, t := range tools { + switch t.Type { + case "web_search": + out = append(out, AnthropicTool{ + Type: "web_search_20250305", + Name: "web_search", + }) + case "function": + out = append(out, AnthropicTool{ + Name: t.Name, + Description: t.Description, + InputSchema: normalizeAnthropicInputSchema(t.Parameters), + }) + default: + // Pass through unknown tool types + out = append(out, AnthropicTool{ + Type: t.Type, + Name: t.Name, + Description: t.Description, + InputSchema: t.Parameters, + }) + } + } + return out +} + +// normalizeAnthropicInputSchema ensures the input_schema has a "type" field. +func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { + if len(schema) == 0 || string(schema) == "null" { + return json.RawMessage(`{"type":"object","properties":{}}`) + } + return schema +} + +// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format. +// Reverse of convertAnthropicToolChoiceToResponses. 
+// +// "auto" → {"type":"auto"} +// "required" → {"type":"any"} +// "none" → {"type":"none"} +// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} +func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try as string first + var s string + if err := json.Unmarshal(raw, &s); err == nil { + switch s { + case "auto": + return json.Marshal(map[string]string{"type": "auto"}) + case "required": + return json.Marshal(map[string]string{"type": "any"}) + case "none": + return json.Marshal(map[string]string{"type": "none"}) + default: + return raw, nil + } + } + + // Try as object with type=function + var tc struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { + return json.Marshal(map[string]string{ + "type": "tool", + "name": tc.Function.Name, + }) + } + + // Pass through unknown + return raw, nil +} diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 02ca1a3b..5154b269 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -29,6 +29,11 @@ INSERT INTO ops_error_logs ( model, request_path, stream, + inbound_endpoint, + upstream_endpoint, + requested_model, + upstream_model, + request_type, user_agent, error_phase, error_type, @@ -57,7 +62,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43 )` func NewOpsRepository(db *sql.DB) service.OpsRepository { @@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) 
[]any { opsNullString(input.Model), opsNullString(input.RequestPath), input.Stream, + opsNullString(input.InboundEndpoint), + opsNullString(input.UpstreamEndpoint), + opsNullString(input.RequestedModel), + opsNullString(input.UpstreamModel), + opsNullInt16(input.RequestType), opsNullString(input.UserAgent), input.ErrorPhase, input.ErrorType, @@ -231,7 +241,12 @@ SELECT COALESCE(g.name, ''), CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), - e.stream + e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type FROM ops_error_logs e LEFT JOIN accounts a ON e.account_id = a.id LEFT JOIN groups g ON e.group_id = g.id @@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) var resolvedBy sql.NullInt64 var resolvedByName string var resolvedRetryID sql.NullInt64 + var requestType sql.NullInt64 if err := rows.Scan( &item.ID, &item.CreatedAt, @@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) &clientIP, &item.RequestPath, &item.Stream, + &item.InboundEndpoint, + &item.UpstreamEndpoint, + &item.RequestedModel, + &item.UpstreamModel, + &requestType, ); err != nil { return nil, err } @@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) item.GroupID = &v } item.GroupName = groupName + if requestType.Valid { + v := int16(requestType.Int64) + item.RequestType = &v + } out = append(out, &item) } if err := rows.Err(); err != nil { @@ -393,6 +418,11 @@ SELECT CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type, COALESCE(e.user_agent, ''), e.auth_latency_ms, e.routing_latency_ms, @@ -427,6 +457,7 @@ LIMIT 1` var responseLatency 
sql.NullInt64 var ttft sql.NullInt64 var requestBodyBytes sql.NullInt64 + var requestType sql.NullInt64 err := r.db.QueryRowContext(ctx, q, id).Scan( &out.ID, @@ -464,6 +495,11 @@ LIMIT 1` &clientIP, &out.RequestPath, &out.Stream, + &out.InboundEndpoint, + &out.UpstreamEndpoint, + &out.RequestedModel, + &out.UpstreamModel, + &requestType, &out.UserAgent, &authLatency, &routingLatency, @@ -540,6 +576,10 @@ LIMIT 1` v := int(requestBodyBytes.Int64) out.RequestBodyBytes = &v } + if requestType.Valid { + v := int16(requestType.Int64) + out.RequestType = &v + } // Normalize request_body to empty string when stored as JSON null. out.RequestBody = strings.TrimSpace(out.RequestBody) @@ -1479,3 +1519,10 @@ func opsNullInt(v any) any { return sql.NullInt64{} } } + +func opsNullInt16(v *int16) any { + if v == nil { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: int64(*v), Valid: true} +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index fe820830..072cfdee 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -69,12 +69,30 @@ func RegisterGatewayRoutes( }) gateway.GET("/models", h.Gateway.Models) gateway.GET("/usage", h.Gateway.Usage) - // OpenAI Responses API - gateway.POST("/responses", h.OpenAIGateway.Responses) - gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) + // OpenAI Responses API: auto-route based on group platform + gateway.POST("/responses", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) + gateway.POST("/responses/*subpath", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API - gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) 
+ // OpenAI Chat Completions API: auto-route based on group platform + gateway.POST("/chat/completions", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -92,12 +110,25 @@ func RegisterGatewayRoutes( gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } - // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) - r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + // OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform + responsesHandler := func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + } + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API(不带v1前缀的别名) - r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) + // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) 
== service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 6ee8280c..aa5d948c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -643,6 +643,7 @@ urlFallbackLoop: AccountID: p.account.ID, AccountName: p.account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -720,6 +721,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), @@ -754,6 +756,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go new file mode 100644 index 00000000..d3c611e2 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -0,0 +1,485 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" 
+ "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts an OpenAI Chat Completions API request body, +// converts it to Anthropic Messages format (chained via Responses format), +// forwards to the Anthropic upstream, and converts the response back to Chat +// Completions format. This enables Chat Completions clients to access Anthropic +// models through Anthropic platform groups. +func (s *GatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var ccReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &ccReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := ccReq.Model + clientStream := ccReq.Stream + includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage + + // 2. Convert CC → Responses → Anthropic (chained conversion) + responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq) + if err != nil { + return nil, fmt.Errorf("convert responses to anthropic: %w", err) + } + + // 3. Force upstream streaming + anthropicReq.Stream = true + reqStream := true + + // 4. 
Model mapping + mappedModel := originalModel + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(originalModel) + } + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(originalModel) + if normalized != originalModel { + mappedModel = normalized + } + } + anthropicReq.Model = mappedModel + + logger.L().Debug("gateway forward_as_chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("client_stream", clientStream), + ) + + // 5. Marshal Anthropic request body + anthropicBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal anthropic request: %w", err) + } + + // 6. Apply Claude Code mimicry for OAuth accounts + isClaudeCode := false // CC API is never Claude Code + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + if !strings.Contains(strings.ToLower(mappedModel), "haiku") && + !systemIncludesClaudeCodePrompt(anthropicReq.System) { + anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) + } + } + + // 7. Enforce cache_control block limit + anthropicBody = enforceCacheControlLimit(anthropicBody) + + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. 
Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 12. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 13. Extract reasoning effort from CC request body + reasoningEffort := extractCCReasoningEffortFromBody(body) + + // 14. Handle normal response + // Read Anthropic SSE → convert to Responses events → convert to CC format + var result *ForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage) + } else { + result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } + + return result, handleErr +} + +// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions +// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort) +// formats used by OpenAI-compatible clients. 
+func extractCCReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full +// response, then converts Anthropic → Responses → Chat Completions. +func (s *GatewayService) handleCCBufferedFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResp *apicompat.AnthropicResponse + var usage ClaudeUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + // message_start carries the initial response structure and cache usage + if event.Type == "message_start" && event.Message != nil { + finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // message_delta carries final usage and stop_reason + if event.Type == "message_delta" { + if event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { + finalResp.StopReason = 
event.Delta.StopReason + } + } + if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil { + finalResp.Content = append(finalResp.Content, *event.ContentBlock) + } + if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil { + idx := *event.Index + if idx < len(finalResp.Content) { + switch event.Delta.Type { + case "text_delta": + finalResp.Content[idx].Text += event.Delta.Text + case "thinking_delta": + finalResp.Content[idx].Thinking += event.Delta.Thinking + case "input_json_delta": + finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON) + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResp == nil { + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response") + return nil, fmt.Errorf("upstream stream ended without response") + } + + // Update usage from accumulated delta + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + finalResp.Usage = apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + // Chain: Anthropic → Responses → Chat Completions + responsesResp := apicompat.AnthropicToResponsesResponse(finalResp) + ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, ccResp) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: 
reasoningEffort, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each +// to Responses events, then to Chat Completions chunks, and writes them. +func (s *GatewayService) handleCCStreamingFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, + includeUsage bool, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + // Use Anthropic→Responses state machine, then convert Responses→CC + anthState := apicompat.NewAnthropicEventToResponsesState() + anthState.Model = originalModel + ccState := apicompat.NewResponsesEventToChatState() + ccState.Model = originalModel + ccState.IncludeUsage = includeUsage + + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *ForwardResult { + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + return false + } + if _, err := fmt.Fprint(c.Writer, sse); 
err != nil { + return true // client disconnected + } + return false + } + + processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // Extract usage from message_delta + if event.Type == "message_delta" && event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + // Also capture usage from message_start (carries cache fields) + if event.Type == "message_start" && event.Message != nil { + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // Chain: Anthropic event → Responses events → CC chunks + responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState) + for _, resEvt := range responsesEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + if disconnected := writeChunk(chunk); disconnected { + return true + } + } + } + c.Writer.Flush() + return false + } + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + if processAnthropicEvent(&event) { + return resultWithUsage(), nil + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Finalize both state machines + finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState) + for _, resEvt := range finalResEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + writeChunk(chunk) 
//nolint:errcheck + } + } + finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState) + for _, chunk := range finalCCChunks { + writeChunk(chunk) //nolint:errcheck + } + + // Write [DONE] marker + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + + return resultWithUsage(), nil +} + +// writeGatewayCCError writes an error in OpenAI Chat Completions format for +// the Anthropic-upstream CC forwarding path. +func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/gateway_forward_as_chat_completions_test.go b/backend/internal/service/gateway_forward_as_chat_completions_test.go new file mode 100644 index 00000000..5003e5b3 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractCCReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + t.Run("nested reasoning.effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + }) + + t.Run("flat reasoning_effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`)) + require.NotNil(t, got) + require.Equal(t, "xhigh", *got) + }) + + t.Run("missing effort", func(t *testing.T) { + require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`))) + }) +} + +func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := 
"high" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) +} + +func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := "medium" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: 
{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "medium", *result.ReasoningEffort) + require.Contains(t, rec.Body.String(), `[DONE]`) +} diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go new file mode 100644 index 00000000..5dca57f9 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -0,0 +1,518 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardAsResponses accepts an OpenAI Responses API request body, converts it +// to Anthropic Messages format, forwards to the Anthropic upstream, and converts +// the response back to Responses format. This enables OpenAI Responses API +// clients to access Anthropic models through Anthropic platform groups. 
+// +// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic +// but in reverse direction: Responses → Anthropic upstream → Responses. +func (s *GatewayService) ForwardAsResponses( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + + // 1. Parse Responses request + var responsesReq apicompat.ResponsesRequest + if err := json.Unmarshal(body, &responsesReq); err != nil { + return nil, fmt.Errorf("parse responses request: %w", err) + } + originalModel := responsesReq.Model + clientStream := responsesReq.Stream + + // 2. Convert Responses → Anthropic + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq) + if err != nil { + return nil, fmt.Errorf("convert responses to anthropic: %w", err) + } + + // 3. Force upstream streaming (Anthropic works best with streaming) + anthropicReq.Stream = true + reqStream := true + + // 4. Model mapping + mappedModel := originalModel + reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(originalModel) + } + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(originalModel) + if normalized != originalModel { + mappedModel = normalized + } + } + anthropicReq.Model = mappedModel + + logger.L().Debug("gateway forward_as_responses: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("client_stream", clientStream), + ) + + // 5. Marshal Anthropic request body + anthropicBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal anthropic request: %w", err) + } + + // 6. 
Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints) + isClaudeCode := false // Responses API is never Claude Code + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + if !strings.Contains(strings.ToLower(mappedModel), "haiku") && + !systemIncludesClaudeCodePrompt(anthropicReq.System) { + anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) + } + } + + // 7. Enforce cache_control block limit + anthropicBody = enforceCacheControlLimit(anthropicBody) + + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 12. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + // Non-failover error: return Responses-formatted error to client + writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 13. Handle normal response (convert Anthropic → Responses) + var result *ForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } else { + result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } + + return result, handleErr +} + +// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort +// and normalizes it for usage logging. 
func ExtractResponsesReasoningEffortFromBody(body []byte) *string {
	raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
	if raw == "" {
		return nil
	}
	normalized := normalizeOpenAIReasoningEffort(raw)
	if normalized == "" {
		// Unrecognized effort values are dropped rather than logged verbatim.
		return nil
	}
	return &normalized
}

// mergeAnthropicUsage copies positive token counts from src into dst.
// Zero-valued fields in src leave dst untouched, so later events overwrite
// earlier totals field-by-field rather than summing (last positive value wins).
func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) {
	if dst == nil {
		return
	}
	if src.InputTokens > 0 {
		dst.InputTokens = src.InputTokens
	}
	if src.OutputTokens > 0 {
		dst.OutputTokens = src.OutputTokens
	}
	if src.CacheReadInputTokens > 0 {
		dst.CacheReadInputTokens = src.CacheReadInputTokens
	}
	if src.CacheCreationInputTokens > 0 {
		dst.CacheCreationInputTokens = src.CacheCreationInputTokens
	}
}

// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from
// the upstream streaming response, assembles them into a complete Anthropic
// response, converts to Responses API JSON format, and writes it to the client.
//
// NOTE(review): the SSE reader assumes each "data: " line immediately follows
// its "event: " line; an event whose next line is not a data line is dropped.
func (s *GatewayService) handleResponsesBufferedStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	// Raise the scanner's token limit so large SSE payload lines are not truncated.
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// Accumulate the final Anthropic response from streaming events
	var finalResp *apicompat.AnthropicResponse
	var usage ClaudeUsage

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		eventType := strings.TrimPrefix(line, "event: ")

		// Read the data line
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		payload := dataLine[6:]

		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			// Malformed events are logged and skipped; the rest of the stream
			// is still consumed.
			logger.L().Warn("forward_as_responses buffered: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
				zap.String("event_type", eventType),
			)
			continue
		}

		// message_start carries the initial response structure (and cache usage)
		if event.Type == "message_start" && event.Message != nil {
			finalResp = event.Message
			mergeAnthropicUsage(&usage, event.Message.Usage)
		}

		// message_delta carries final usage and stop_reason
		if event.Type == "message_delta" {
			if event.Usage != nil {
				mergeAnthropicUsage(&usage, *event.Usage)
			}
			if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil {
				finalResp.StopReason = event.Delta.StopReason
			}
		}

		// Accumulate content blocks; deltas are appended to the block at the
		// event's index (out-of-range indices are silently ignored).
		if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil {
			finalResp.Content = append(finalResp.Content, *event.ContentBlock)
		}
		if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil {
			idx := *event.Index
			if idx < len(finalResp.Content) {
				switch event.Delta.Type {
				case "text_delta":
					finalResp.Content[idx].Text += event.Delta.Text
				case "thinking_delta":
					finalResp.Content[idx].Thinking += event.Delta.Thinking
				case "input_json_delta":
					finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON)
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		// Client-driven cancellation is expected and not worth logging.
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_responses buffered: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	if finalResp == nil {
		writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response")
		return nil, fmt.Errorf("upstream stream ended without response")
	}

	// Update usage from accumulated delta (only when real token counts arrived;
	// otherwise keep whatever message_start carried)
	if usage.InputTokens > 0 || usage.OutputTokens > 0 {
		finalResp.Usage = apicompat.AnthropicUsage{
			InputTokens:              usage.InputTokens,
			OutputTokens:             usage.OutputTokens,
			CacheCreationInputTokens: usage.CacheCreationInputTokens,
			CacheReadInputTokens:     usage.CacheReadInputTokens,
		}
	}

	// Convert to Responses format
	responsesResp := apicompat.AnthropicToResponsesResponse(finalResp)
	responsesResp.Model = originalModel // Use original model name

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.JSON(http.StatusOK, responsesResp)

	return &ForwardResult{
		RequestID:       requestID,
		Usage:           usage,
		Model:           originalModel,
		UpstreamModel:   mappedModel,
		ReasoningEffort: reasoningEffort,
		Stream:          false,
		Duration:        time.Since(startTime),
	}, nil
}

// handleResponsesStreamingResponse reads Anthropic SSE events from upstream,
// converts each to Responses SSE events, and writes them to the client.
func (s *GatewayService) handleResponsesStreamingResponse(
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
	reasoningEffort *string,
	startTime time.Time,
) (*ForwardResult, error) {
	requestID := resp.Header.Get("x-request-id")

	// Mirror selected upstream headers first, then force SSE headers so they
	// cannot be overridden by the filtered set.
	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	// Conversion state is shared across all events of this stream; the final
	// events emitted by finalizeStream also read from it.
	state := apicompat.NewAnthropicEventToResponsesState()
	state.Model = originalModel
	var usage ClaudeUsage
	var firstTokenMs *int
	firstChunk := true

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	// Allow oversized SSE data lines up to the configured limit.
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)

	// resultWithUsage snapshots the accounting collected so far; called on
	// every exit path (normal completion, disconnect, read error).
	resultWithUsage := func() *ForwardResult {
		return &ForwardResult{
			RequestID:       requestID,
			Usage:           usage,
			Model:           originalModel,
			UpstreamModel:   mappedModel,
			ReasoningEffort: reasoningEffort,
			Stream:          true,
			Duration:        time.Since(startTime),
			FirstTokenMs:    firstTokenMs,
		}
	}

	// processEvent handles a single parsed Anthropic SSE event.
	// It returns true when the client has disconnected and streaming must stop.
	processEvent := func(event *apicompat.AnthropicStreamEvent) bool {
		// First-token latency is measured from the first upstream event of any
		// type, not from the first content delta.
		if firstChunk {
			firstChunk = false
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}

		// Extract usage from message_delta
		if event.Type == "message_delta" && event.Usage != nil {
			mergeAnthropicUsage(&usage, *event.Usage)
		}
		// Also capture usage from message_start
		if event.Type == "message_start" && event.Message != nil {
			mergeAnthropicUsage(&usage, event.Message.Usage)
		}

		// Convert to Responses events
		events := apicompat.AnthropicEventToResponsesEvents(event, state)
		for _, evt := range events {
			sse, err := apicompat.ResponsesEventToSSE(evt)
			if err != nil {
				// A marshal failure drops only this event; the stream continues.
				logger.L().Warn("forward_as_responses stream: failed to marshal event",
					zap.Error(err),
					zap.String("request_id", requestID),
				)
				continue
			}
			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
				logger.L().Info("forward_as_responses stream: client disconnected",
					zap.String("request_id", requestID),
				)
				return true // client disconnected
			}
		}
		if len(events) > 0 {
			c.Writer.Flush()
		}
		return false
	}

	// finalizeStream emits any closing events the converter still holds
	// (best-effort: marshal/write errors here are intentionally ignored)
	// and returns the accumulated result.
	finalizeStream := func() (*ForwardResult, error) {
		if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 {
			for _, evt := range finalEvents {
				sse, err := apicompat.ResponsesEventToSSE(evt)
				if err != nil {
					continue
				}
				fmt.Fprint(c.Writer, sse) //nolint:errcheck
			}
			c.Writer.Flush()
		}
		return resultWithUsage(), nil
	}

	// Read Anthropic SSE events.
	// NOTE(review): this assumes each "event: <type>" line is immediately
	// followed by a single "data: {...}" line (the framing Anthropic emits);
	// multi-line data payloads or interleaved comment lines would be skipped.
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "event: ") {
			continue
		}
		eventType := strings.TrimPrefix(line, "event: ")

		// Read data line
		if !scanner.Scan() {
			break
		}
		dataLine := scanner.Text()
		if !strings.HasPrefix(dataLine, "data: ") {
			continue
		}
		// Safe: the HasPrefix check above guarantees len(dataLine) >= 6
		// (len("data: ") == 6).
		payload := dataLine[6:]

		var event apicompat.AnthropicStreamEvent
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			logger.L().Warn("forward_as_responses stream: failed to parse event",
				zap.Error(err),
				zap.String("request_id", requestID),
				zap.String("event_type", eventType),
			)
			continue
		}

		if processEvent(&event) {
			// Client went away: return what we have, skipping finalization
			// events nobody would receive.
			return resultWithUsage(), nil
		}
	}

	// Context cancellation/deadline is a normal client-side termination and
	// is not logged as an upstream read error.
	if err := scanner.Err(); err != nil {
		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
			logger.L().Warn("forward_as_responses stream: read error",
				zap.Error(err),
				zap.String("request_id", requestID),
			)
		}
	}

	return finalizeStream()
}

// appendRawJSON appends a JSON fragment string to existing raw JSON.
// NOTE(review): this is plain byte concatenation — the caller is responsible
// for the result being well-formed JSON; no validation happens here.
func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage {
	if len(existing) == 0 {
		return json.RawMessage(fragment)
	}
	return json.RawMessage(string(existing) + fragment)
}

// writeResponsesError writes an error response in OpenAI Responses API format:
// {"error": {"code": ..., "message": ...}} with the given HTTP status.
func writeResponsesError(c *gin.Context, statusCode int, code, message string) {
	c.JSON(statusCode, gin.H{
		"error": gin.H{
			"code":    code,
			"message": message,
		},
	})
}

// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes.
+func mapUpstreamStatusCode(code int) int { + if code >= 500 { + return http.StatusBadGateway + } + return code +} diff --git a/backend/internal/service/gateway_forward_as_responses_test.go b/backend/internal/service/gateway_forward_as_responses_test.go new file mode 100644 index 00000000..e48d8b22 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses_test.go @@ -0,0 +1,94 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractResponsesReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + + require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`))) +} + +func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, 
"claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `"cached_tokens":9`) +} + +func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `response.completed`) +} diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 29b6cfd6..1d3e39e8 
100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "math" + "regexp" + "sort" "strings" "unsafe" @@ -34,6 +36,9 @@ var ( patternEmptyTextSpaced = []byte(`"text": ""`) patternEmptyTextSp1 = []byte(`"text" : ""`) patternEmptyTextSp2 = []byte(`"text" :""`) + + sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`) + sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -75,6 +80,49 @@ type ParsedRequest struct { OnUpstreamAccepted func() } +// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing. +// It preserves the set of product names from Product/Version tokens while +// discarding version-only changes and incidental comments. +func NormalizeSessionUserAgent(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + + matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1) + if len(matches) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + + products := make([]string, 0, len(matches)) + seen := make(map[string]struct{}, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + product := strings.ToLower(strings.TrimSpace(match[1])) + if product == "" { + continue + } + if _, exists := seen[product]; exists { + continue + } + seen[product] = struct{}{} + products = append(products, product) + } + if len(products) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + sort.Strings(products) + return strings.Join(products, "+") +} + +func normalizeSessionUserAgentFallback(raw string) string { + normalized := strings.ToLower(strings.Join(strings.Fields(raw), " ")) + normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "") + return strings.Join(strings.Fields(normalized), " ") +} + // ParseGatewayRequest 解析网关请求体并返回结构化结果。 // 
protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 72cef2ac..9dd39276 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -658,7 +658,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { if parsed.SessionContext != nil { _, _ = combined.WriteString(parsed.SessionContext.ClientIP) _, _ = combined.WriteString(":") - _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent)) _, _ = combined.WriteString(":") _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) _, _ = combined.WriteString("|") @@ -4148,6 +4148,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -4174,6 +4175,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "signature_error", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4228,6 +4230,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: retryResp.StatusCode, UpstreamRequestID: retryResp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(retryReq.URL.String()), Kind: "signature_retry_thinking", Message: extractUpstreamErrorMessage(retryRespBody), Detail: func() string { @@ -4258,6 +4261,7 @@ func (s *GatewayService) Forward(ctx context.Context, c 
*gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(retryReq2.URL.String()), Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) @@ -4297,6 +4301,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "budget_constraint_error", Message: errMsg, Detail: func() string { @@ -4358,6 +4363,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4628,6 +4634,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: safeErr, @@ -4667,6 +4674,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "retry", Message: extractUpstreamErrorMessage(respBody), @@ -5344,6 +5352,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -5380,6 +5389,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 
resp.StatusCode, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -8064,6 +8074,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: sanitizeUpstreamErrorMessage(err.Error()), @@ -8119,6 +8130,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "http_error", Message: upstreamMsg, diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 1780d1da..cd291328 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string { // 返回 16 字符的 Base64 编码的 SHA256 前缀 func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { // 组合所有标识符 + normalizedUserAgent := NormalizeSessionUserAgent(userAgent) combined := strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(apiKeyID, 10) + ":" + ip + ":" + - userAgent + ":" + + normalizedUserAgent + ":" + platform + ":" + model diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index a034cddd..27321996 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } +func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", 
"Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + +func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index f91fb4c9..39679c3d 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") } +func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0")) + h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1")) + require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash") +} + +func 
TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0")) + h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1")) + require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash") +} + func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { svc := &GatewayService{} diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 2ed06d90..5fefb74f 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -62,6 +62,12 @@ type OpsErrorLog struct { ClientIP *string `json:"client_ip"` RequestPath string `json:"request_path"` Stream bool `json:"stream"` + + InboundEndpoint string `json:"inbound_endpoint"` + UpstreamEndpoint string `json:"upstream_endpoint"` + RequestedModel string `json:"requested_model"` + UpstreamModel string `json:"upstream_model"` + RequestType *int16 `json:"request_type"` } type OpsErrorLogDetail struct { diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 0ce9d425..04bf91c8 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct { Model string RequestPath string Stream bool + // InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions. + InboundEndpoint string + // UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses. + UpstreamEndpoint string + // RequestedModel is the client-requested model name before mapping. 
+ RequestedModel string + // UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping. + UpstreamModel string + // RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2. + // Matches service.RequestType enum semantics from usage_log.go. + RequestType *int16 UserAgent string ErrorPhase string diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 9adf5896..05d444e1 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct { UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"` + // UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped). + // Helps debug 404/routing errors by showing which endpoint was targeted. + UpstreamURL string `json:"upstream_url,omitempty"` + // Best-effort upstream request capture (sanitized+trimmed). // Required for retrying a specific upstream attempt. UpstreamRequestBody string `json:"upstream_request_body,omitempty"` @@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.Kind = strings.TrimSpace(ev.Kind) + ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL) ev.Message = strings.TrimSpace(ev.Message) ev.Detail = strings.TrimSpace(ev.Detail) if ev.Message != "" { @@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) { } return out, nil } + +// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment +// to avoid leaking sensitive query parameters (e.g. OAuth tokens). 
+func safeUpstreamURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + if idx := strings.IndexByte(rawURL, '?'); idx >= 0 { + rawURL = rawURL[:idx] + } + if idx := strings.IndexByte(rawURL, '#'); idx >= 0 { + rawURL = rawURL[:idx] + } + return rawURL +} diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go index 50ceaa0e..fa6d1085 100644 --- a/backend/internal/service/ops_upstream_context_test.go +++ b/backend/internal/service/ops_upstream_context_test.go @@ -8,6 +8,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestSafeUpstreamURL(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"}, + {"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"}, + {"strips both", "https://host/path?token=secret#x", "https://host/path"}, + {"no query or fragment", "https://host/path", "https://host/path"}, + {"empty string", "", ""}, + {"whitespace only", " ", ""}, + {"query before fragment", "https://h/p?a=1#f", "https://h/p"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, safeUpstreamURL(tt.input)) + }) + } +} + func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql new file mode 100644 index 00000000..56f83b84 --- /dev/null +++ b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql @@ -0,0 +1,28 @@ +-- Ops error logs: add endpoint, model mapping, and request_type fields +-- to match usage_logs observability coverage. 
+-- +-- All columns are nullable with no default to preserve backward compatibility +-- with existing rows. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256), + ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256); + +-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100), + ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); + +-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS request_type SMALLINT; + +COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.'; +COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.'; +COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).'; +COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.'; +COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. 
Matches usage_logs.request_type semantics.'; diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index 64f6a6d0..ac58eff4 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -969,6 +969,13 @@ export interface OpsErrorLog { client_ip?: string | null request_path?: string stream?: boolean + + // Error observability context (endpoint + model mapping) + inbound_endpoint?: string + upstream_endpoint?: string + requested_model?: string + upstream_model?: string + request_type?: number | null } export interface OpsErrorDetail extends OpsErrorLog { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index a2f69e2c..de40fa5e 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3494,7 +3494,12 @@ export default { typeRequest: 'Request', typeAuth: 'Auth', typeRouting: 'Routing', - typeInternal: 'Internal' + typeInternal: 'Internal', + endpoint: 'Endpoint', + requestType: 'Type', + requestTypeSync: 'Sync', + requestTypeStream: 'Stream', + requestTypeWs: 'WS' }, // Error Details Modal errorDetails: { @@ -3580,6 +3585,16 @@ export default { latency: 'Request Duration', businessLimited: 'Business Limited', requestPath: 'Request Path', + inboundEndpoint: 'Inbound Endpoint', + upstreamEndpoint: 'Upstream Endpoint', + requestedModel: 'Requested Model', + upstreamModel: 'Upstream Model', + requestType: 'Request Type', + requestTypeUnknown: 'Unknown', + requestTypeSync: 'Sync', + requestTypeStream: 'Stream', + requestTypeWs: 'WebSocket', + modelMapping: 'Model Mapping', timings: 'Timings', auth: 'Auth', routing: 'Routing', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2eef299c..d5ed956c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3659,7 +3659,12 @@ export default { typeRequest: '请求', typeAuth: '认证', typeRouting: '路由', - typeInternal: '内部' + typeInternal: '内部', + endpoint: '端点', + requestType: 
'类型', + requestTypeSync: '同步', + requestTypeStream: '流式', + requestTypeWs: 'WS' }, // Error Details Modal errorDetails: { @@ -3745,6 +3750,16 @@ export default { latency: '请求时长', businessLimited: '业务限制', requestPath: '请求路径', + inboundEndpoint: '入站端点', + upstreamEndpoint: '上游端点', + requestedModel: '请求模型', + upstreamModel: '上游模型', + requestType: '请求类型', + requestTypeUnknown: '未知', + requestTypeSync: '同步', + requestTypeStream: '流式', + requestTypeWs: 'WebSocket', + modelMapping: '模型映射', timings: '时序信息', auth: '认证', routing: '路由', diff --git a/frontend/src/views/admin/ops/components/OpsErrorDetailModal.vue b/frontend/src/views/admin/ops/components/OpsErrorDetailModal.vue index a7edff96..d29607e5 100644 --- a/frontend/src/views/admin/ops/components/OpsErrorDetailModal.vue +++ b/frontend/src/views/admin/ops/components/OpsErrorDetailModal.vue @@ -59,7 +59,28 @@
{{ t('admin.ops.errorDetail.model') }}
- {{ detail.model || '—' }} + + +
+
+ +
+
{{ t('admin.ops.errorDetail.inboundEndpoint') }}
+
+ {{ detail.inbound_endpoint || '—' }} +
+
+ +
+
{{ t('admin.ops.errorDetail.upstreamEndpoint') }}
+
+ {{ detail.upstream_endpoint || '—' }}
@@ -72,6 +93,13 @@ +
+
{{ t('admin.ops.errorDetail.requestType') }}
+
+ {{ formatRequestTypeLabel(detail.request_type) }} +
+
+
{{ t('admin.ops.errorDetail.message') }}
@@ -213,6 +241,31 @@ function isUpstreamError(d: OpsErrorDetail | null): boolean { return phase === 'upstream' && owner === 'provider' } +function formatRequestTypeLabel(type: number | null | undefined): string { + switch (type) { + case 1: return t('admin.ops.errorDetail.requestTypeSync') + case 2: return t('admin.ops.errorDetail.requestTypeStream') + case 3: return t('admin.ops.errorDetail.requestTypeWs') + default: return t('admin.ops.errorDetail.requestTypeUnknown') + } +} + +function hasModelMapping(d: OpsErrorDetail | null): boolean { + if (!d) return false + const requested = String(d.requested_model || '').trim() + const upstream = String(d.upstream_model || '').trim() + return !!requested && !!upstream && requested !== upstream +} + +function displayModel(d: OpsErrorDetail | null): string { + if (!d) return '' + const upstream = String(d.upstream_model || '').trim() + if (upstream) return upstream + const requested = String(d.requested_model || '').trim() + if (requested) return requested + return String(d.model || '').trim() +} + const correlatedUpstream = ref([]) const correlatedUpstreamLoading = ref(false) diff --git a/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue b/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue index 28868552..2b3825a2 100644 --- a/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue +++ b/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue @@ -17,6 +17,9 @@ {{ t('admin.ops.errorLog.type') }} + + {{ t('admin.ops.errorLog.endpoint') }} + {{ t('admin.ops.errorLog.platform') }} @@ -42,7 +45,7 @@ - + {{ t('admin.ops.errorLog.noErrors') }} @@ -74,6 +77,18 @@ + + +
+ + + {{ log.inbound_endpoint }} + + + - +
+ + @@ -83,11 +98,22 @@ -
- - {{ log.model }} - - - +
+ +
@@ -138,6 +164,12 @@ > {{ log.severity }} + + {{ formatRequestType(log.request_type) }} +
@@ -193,6 +225,44 @@ function isUpstreamRow(log: OpsErrorLog): boolean { return phase === 'upstream' && owner === 'provider' } +function formatEndpointTooltip(log: OpsErrorLog): string { + const parts: string[] = [] + if (log.inbound_endpoint) parts.push(`Inbound: ${log.inbound_endpoint}`) + if (log.upstream_endpoint) parts.push(`Upstream: ${log.upstream_endpoint}`) + return parts.join('\n') || '' +} + +function hasModelMapping(log: OpsErrorLog): boolean { + const requested = String(log.requested_model || '').trim() + const upstream = String(log.upstream_model || '').trim() + return !!requested && !!upstream && requested !== upstream +} + +function modelMappingTooltip(log: OpsErrorLog): string { + const requested = String(log.requested_model || '').trim() + const upstream = String(log.upstream_model || '').trim() + if (!requested && !upstream) return '' + if (requested && upstream) return `${requested} → ${upstream}` + return upstream || requested +} + +function displayModel(log: OpsErrorLog): string { + const upstream = String(log.upstream_model || '').trim() + if (upstream) return upstream + const requested = String(log.requested_model || '').trim() + if (requested) return requested + return String(log.model || '').trim() +} + +function formatRequestType(type: number | null | undefined): string { + switch (type) { + case 1: return t('admin.ops.errorLog.requestTypeSync') + case 2: return t('admin.ops.errorLog.requestTypeStream') + case 3: return t('admin.ops.errorLog.requestTypeWs') + default: return '' + } +} + function getTypeBadge(log: OpsErrorLog): { label: string; className: string } { const phase = String(log.phase || '').toLowerCase() const owner = String(log.error_owner || '').toLowerCase() @@ -263,4 +333,4 @@ function formatSmartMessage(msg: string): string { return msg.length > 200 ? msg.substring(0, 200) + '...' : msg } - \ No newline at end of file +