diff --git a/.dockerignore b/.dockerignore index ab803d44..770145cb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -61,6 +61,9 @@ temp/ deploy/install.sh deploy/sub2api.service deploy/sub2api-sudoers +deploy/data/ +deploy/postgres_data/ +deploy/redis_data/ # GoReleaser .goreleaser.yaml diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 63c5ed0e..2c1ac5b0 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -114,6 +114,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() driveClient := repository.NewGeminiDriveClient() diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index fbac73d3..322ae590 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -352,7 +352,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "") if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index f762511c..9eaf0bfd 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -219,6 +219,7 @@ func 
(h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + privacyMode := strings.TrimSpace(c.Query("privacy_mode")) // 标准化和验证 search 参数 search = strings.TrimSpace(search) if len(search) > 100 { @@ -244,7 +245,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) if err != nil { response.ErrorFrom(c, err) return @@ -1936,7 +1937,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 61e2c2bd..4ed0a623 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -187,7 +187,7 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } diff --git a/backend/internal/handler/admin/setting_handler.go 
b/backend/internal/handler/admin/setting_handler.go index c91566c8..b5a7eb77 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -110,6 +110,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -176,6 +177,7 @@ type UpdateSettingsRequest struct { PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -231,11 +233,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.DefaultBalance < 0 { req.DefaultBalance = 0 } + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + req.SMTPPassword = strings.TrimSpace(req.SMTPPassword) + req.SMTPFrom = strings.TrimSpace(req.SMTPFrom) + req.SMTPFromName = strings.TrimSpace(req.SMTPFromName) if req.SMTPPort <= 0 { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 + // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 + if req.SMTPHost == "" && previousSettings.SMTPHost != "" { + req.SMTPHost = previousSettings.SMTPHost + req.SMTPPort = previousSettings.SMTPPort + req.SMTPUsername = previousSettings.SMTPUsername + req.SMTPFrom = previousSettings.SMTPFrom + req.SMTPFromName = previousSettings.SMTPFromName + req.SMTPUseTLS = previousSettings.SMTPUseTLS + } + // Turnstile 参数验证 
if req.TurnstileEnabled { // 检查必填字段 @@ -417,6 +435,55 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { customMenuJSON = string(menuBytes) } + // 自定义端点验证 + const ( + maxCustomEndpoints = 10 + maxEndpointNameLen = 50 + maxEndpointURLLen = 2048 + maxEndpointDescriptionLen = 200 + ) + + customEndpointsJSON := previousSettings.CustomEndpoints + if req.CustomEndpoints != nil { + endpoints := *req.CustomEndpoints + if len(endpoints) > maxCustomEndpoints { + response.BadRequest(c, "Too many custom endpoints (max 10)") + return + } + for _, ep := range endpoints { + if strings.TrimSpace(ep.Name) == "" { + response.BadRequest(c, "Custom endpoint name is required") + return + } + if len(ep.Name) > maxEndpointNameLen { + response.BadRequest(c, "Custom endpoint name is too long (max 50 characters)") + return + } + if strings.TrimSpace(ep.Endpoint) == "" { + response.BadRequest(c, "Custom endpoint URL is required") + return + } + if len(ep.Endpoint) > maxEndpointURLLen { + response.BadRequest(c, "Custom endpoint URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(ep.Endpoint)); err != nil { + response.BadRequest(c, "Custom endpoint URL must be an absolute http(s) URL") + return + } + if len(ep.Description) > maxEndpointDescriptionLen { + response.BadRequest(c, "Custom endpoint description is too long (max 200 characters)") + return + } + } + endpointBytes, err := json.Marshal(endpoints) + if err != nil { + response.BadRequest(c, "Failed to serialize custom endpoints") + return + } + customEndpointsJSON = string(endpointBytes) + } + // Ops metrics collector interval validation (seconds). 
if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -495,6 +562,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: purchaseURL, SoraClientEnabled: req.SoraClientEnabled, CustomMenuItems: customMenuJSON, + CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, DefaultSubscriptions: defaultSubscriptions, @@ -592,6 +660,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, DefaultSubscriptions: updatedDefaultSubscriptions, @@ -828,7 +897,7 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { - SMTPHost string `json:"smtp_host" binding:"required"` + SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` SMTPUsername string `json:"smtp_username"` SMTPPassword string `json:"smtp_password"` @@ -844,18 +913,35 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { return } - if req.SMTPPort <= 0 { - req.SMTPPort = 587 + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + + var savedConfig *service.SMTPConfig + if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil { + savedConfig = cfg } - // 如果未提供密码,从数据库获取已保存的密码 - password := req.SMTPPassword - if password == "" { - savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) - if err == nil && savedConfig != nil { - password = savedConfig.Password + if req.SMTPHost == "" && savedConfig != nil { + 
req.SMTPHost = savedConfig.Host + } + if req.SMTPPort <= 0 { + if savedConfig != nil && savedConfig.Port > 0 { + req.SMTPPort = savedConfig.Port + } else { + req.SMTPPort = 587 } } + if req.SMTPUsername == "" && savedConfig != nil { + req.SMTPUsername = savedConfig.Username + } + password := strings.TrimSpace(req.SMTPPassword) + if password == "" && savedConfig != nil { + password = savedConfig.Password + } + if req.SMTPHost == "" { + response.BadRequest(c, "SMTP host is required") + return + } config := &service.SMTPConfig{ Host: req.SMTPHost, @@ -877,7 +963,7 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) { // SendTestEmailRequest 发送测试邮件请求 type SendTestEmailRequest struct { Email string `json:"email" binding:"required,email"` - SMTPHost string `json:"smtp_host" binding:"required"` + SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` SMTPUsername string `json:"smtp_username"` SMTPPassword string `json:"smtp_password"` @@ -895,18 +981,43 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { return } - if req.SMTPPort <= 0 { - req.SMTPPort = 587 + req.SMTPHost = strings.TrimSpace(req.SMTPHost) + req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) + req.SMTPFrom = strings.TrimSpace(req.SMTPFrom) + req.SMTPFromName = strings.TrimSpace(req.SMTPFromName) + + var savedConfig *service.SMTPConfig + if cfg, err := h.emailService.GetSMTPConfig(c.Request.Context()); err == nil && cfg != nil { + savedConfig = cfg } - // 如果未提供密码,从数据库获取已保存的密码 - password := req.SMTPPassword - if password == "" { - savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context()) - if err == nil && savedConfig != nil { - password = savedConfig.Password + if req.SMTPHost == "" && savedConfig != nil { + req.SMTPHost = savedConfig.Host + } + if req.SMTPPort <= 0 { + if savedConfig != nil && savedConfig.Port > 0 { + req.SMTPPort = savedConfig.Port + } else { + req.SMTPPort = 587 } } + if req.SMTPUsername == "" && savedConfig != nil { + req.SMTPUsername = 
savedConfig.Username + } + password := strings.TrimSpace(req.SMTPPassword) + if password == "" && savedConfig != nil { + password = savedConfig.Password + } + if req.SMTPFrom == "" && savedConfig != nil { + req.SMTPFrom = savedConfig.From + } + if req.SMTPFromName == "" && savedConfig != nil { + req.SMTPFromName = savedConfig.FromName + } + if req.SMTPHost == "" { + response.BadRequest(c, "SMTP host is required") + return + } config := &service.SMTPConfig{ Host: req.SMTPHost, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 0f4f8fdc..7ea34aa0 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -15,6 +15,13 @@ type CustomMenuItem struct { SortOrder int `json:"sort_order"` } +// CustomEndpoint represents an admin-configured API endpoint for quick copy. +type CustomEndpoint struct { + Name string `json:"name"` + Endpoint string `json:"endpoint"` + Description string `json:"description"` +} + // SystemSettings represents the admin settings API response payload. 
type SystemSettings struct { RegistrationEnabled bool `json:"registration_enabled"` @@ -56,6 +63,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -114,6 +122,7 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` @@ -218,3 +227,17 @@ func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { } return filtered } + +// ParseCustomEndpoints parses a JSON string into a slice of CustomEndpoint. +// Returns empty slice on empty/invalid input. 
+func ParseCustomEndpoints(raw string) []CustomEndpoint { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomEndpoint{} + } + var items []CustomEndpoint + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomEndpoint{} + } + return items +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index e1b1b9a8..b9285c04 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -178,6 +178,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 验证 model 必填 if reqModel == "" { @@ -1396,6 +1397,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false))) // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go new file mode 100644 index 00000000..da376036 --- /dev/null +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -0,0 +1,289 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API endpoint for Anthropic platform groups. 
+// POST /v1/chat/completions +// This converts Chat Completions requests to Anthropic format (via Responses format chain), +// forwards to Anthropic upstream, and converts responses back to Chat Completions format. +func (h *GatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.chatCompletionsErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + 
reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.cc.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.chatCompletionsErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.cc.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = 
wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.chatCompletionsErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "chat_completions") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.cc.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleCCFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleCCFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.cc.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.cc.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// chatCompletionsErrorResponse writes an error in OpenAI Chat Completions format. +func (h *GatewayHandler) chatCompletionsErrorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// handleCCFailoverExhausted writes a failover-exhausted error in CC format. 
+func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go new file mode 100644 index 00000000..d146d724 --- /dev/null +++ b/backend/internal/handler/gateway_handler_responses.go @@ -0,0 +1,295 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// Responses handles OpenAI Responses API endpoint for Anthropic platform groups. +// POST /v1/responses +// This converts Responses API requests to Anthropic format, forwards to Anthropic +// upstream, and converts responses back to Responses format. 
+func (h *GatewayHandler) Responses(c *gin.Context) { + streamStarted := false + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read request body + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + + if len(body) == 0 { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // Validate JSON + if !gjson.ValidBytes(body) { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // Extract model and stream using gjson (like OpenAI handler) + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + 
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + + // Claude Code only restriction: + // /v1/responses is never a Claude Code endpoint. + // When claude_code_only is enabled, this endpoint is rejected. + // The existing service-layer checkClaudeCodeRestriction handles degradation + // to fallback groups when the Forward path calls SelectAccountForModelWithExclusions. + // Here we just reject at handler level since /v1/responses clients can't be Claude Code. + if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { + h.responsesErrorResponse(c, http.StatusForbidden, "permission_error", + "This group is restricted to Claude Code clients (/v1/messages only)") + return + } + + // Error passthrough binding + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + + // 1. 
Acquire user concurrency slot + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2. 
Re-check billing + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.responsesErrorResponse(c, status, code, message) + return + } + + // Parse request for session hash + parsedReq, _ := service.ParseGatewayRequest(body, "responses") + if parsedReq == nil { + parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body} + } + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + + // 3. Account selection + failover loop + fs := NewFailoverState(h.maxAccountSwitches, false) + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + if err != nil { + if len(fs.FailedAccountIDs) == 0 { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + return + } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + continue + case FailoverCanceled: + return + default: + if fs.LastFailoverErr != nil { + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + } else { + h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted") + } + return + } + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + + // 4. 
Acquire account concurrency slot + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 5. Forward request + writerSizeBeforeForward := c.Writer.Size() + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + + if accountReleaseFunc != nil { + accountReleaseFunc() + } + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + // Can't failover if streaming content already sent + if c.Writer.Size() != writerSizeBeforeForward { + h.handleResponsesFailoverExhausted(c, failoverErr, true) + return + } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted) + return + case FailoverCanceled: + return + } + } + h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.responses.forward_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + return + } + + // 6. 
Record usage + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("gateway.responses.record_usage_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + ) + } + }) + return + } +} + +// responsesErrorResponse writes an error in OpenAI Responses API format. +func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": code, + "message": message, + }, + }) +} + +// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format. 
+func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) { + if streamStarted { + return // Can't write error after stream started + } + statusCode := http.StatusBadGateway + if lastErr != nil && lastErr.StatusCode > 0 { + statusCode = lastErr.StatusCode + } + h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index fb231898..5dc03b6d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -182,6 +182,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } setOpsRequestContext(c, modelName, stream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index dd158d8b..0c94aa21 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -77,6 +77,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index b7f18d21..3ce6e5d6 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -183,6 +183,7 @@ func (h *OpenAIGatewayHandler) 
Responses(c *gin.Context) { } setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { @@ -545,6 +546,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { @@ -1096,6 +1098,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.String("previous_response_id_kind", previousResponseIDKind), ) setOpsRequestContext(c, reqModel, true, firstMessage) + setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) var currentUserRelease func() var currentAccountRelease func() diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index ceb06f0e..90e90dd0 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -27,6 +27,9 @@ const ( opsRequestBodyKey = "ops_request_body" opsAccountIDKey = "ops_account_id" + opsUpstreamModelKey = "ops_upstream_model" + opsRequestTypeKey = "ops_request_type" + // 错误过滤匹配常量 — shouldSkipOpsErrorLog 和错误分类共用 opsErrContextCanceled = "context canceled" opsErrNoAvailableAccounts = "no available accounts" @@ -345,6 +348,18 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody } } +// setOpsEndpointContext stores upstream model and request type for ops error logging. +// Called by handlers after model mapping and request type determination. 
+func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int16) { + if c == nil { + return + } + if upstreamModel = strings.TrimSpace(upstreamModel); upstreamModel != "" { + c.Set(opsUpstreamModelKey, upstreamModel) + } + c.Set(opsRequestTypeKey, requestType) +} + func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { if c == nil || entry == nil { return @@ -628,7 +643,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), + RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: "upstream", @@ -756,7 +794,30 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } return "" }(), - Stream: stream, + Stream: stream, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, platform), + RequestedModel: modelName, + UpstreamModel: func() string { + if v, ok := c.Get(opsUpstreamModelKey); ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + return "" + }(), + RequestType: func() *int16 { + if v, ok := c.Get(opsRequestTypeKey); ok { + switch t := v.(type) { + case int16: + return &t + case int: + v16 := int16(t) + return &v16 + } + } + return nil + }(), UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 679dd4ce..6ae45110 100644 --- 
a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -274,3 +274,48 @@ func TestNormalizeOpsErrorType(t *testing.T) { }) } } + +func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "claude-3-5-sonnet-20241022", int16(2)) // stream + + v, ok := c.Get(opsUpstreamModelKey) + require.True(t, ok) + vStr, ok := v.(string) + require.True(t, ok) + require.Equal(t, "claude-3-5-sonnet-20241022", vStr) + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(2), rtVal) +} + +func TestSetOpsEndpointContext_EmptyModelNotStored(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + setOpsEndpointContext(c, "", int16(1)) + + _, ok := c.Get(opsUpstreamModelKey) + require.False(t, ok, "empty upstream model should not be stored") + + rt, ok := c.Get(opsRequestTypeKey) + require.True(t, ok) + rtVal, ok := rt.(int16) + require.True(t, ok) + require.Equal(t, int16(1), rtVal) +} + +func TestSetOpsEndpointContext_NilContext(t *testing.T) { + require.NotPanics(t, func() { + setOpsEndpointContext(nil, "model", int16(1)) + }) +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 92061895..2c999cf1 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -52,6 +52,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, CustomMenuItems: 
dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 89dcd394..625f159d 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2072,7 +2072,7 @@ func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index cc1b1c0b..5e505409 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -159,6 +159,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { } setOpsRequestContext(c, reqModel, clientStream, body) + setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false))) platform := "" if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { diff --git a/backend/internal/handler/sora_gateway_handler_test.go 
b/backend/internal/handler/sora_gateway_handler_test.go index 5c631132..084e4ae1 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -130,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses_response.go b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go new file mode 100644 index 00000000..9290e399 --- /dev/null +++ b/backend/internal/pkg/apicompat/anthropic_to_responses_response.go @@ -0,0 +1,521 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: AnthropicResponse → ResponsesResponse +// --------------------------------------------------------------------------- + +// AnthropicToResponsesResponse converts an Anthropic Messages response into a +// Responses API response. This is the reverse of ResponsesToAnthropic and +// enables Anthropic upstream responses to be returned in OpenAI Responses format. 
func AnthropicToResponsesResponse(resp *AnthropicResponse) *ResponsesResponse {
	id := resp.ID
	if id == "" {
		// Upstream supplied no ID; synthesize a "resp_..." identifier.
		id = generateResponsesID()
	}

	out := &ResponsesResponse{
		ID:     id,
		Object: "response",
		Model:  resp.Model,
	}

	var outputs []ResponsesOutput
	var msgParts []ResponsesContentPart

	// Walk the Anthropic content blocks. Text parts are accumulated and
	// emitted as ONE trailing "message" item below, so the original relative
	// order of text vs. tool_use blocks is not preserved in the output.
	for _, block := range resp.Content {
		switch block.Type {
		case "thinking":
			if block.Thinking != "" {
				outputs = append(outputs, ResponsesOutput{
					Type: "reasoning",
					ID:   generateItemID(),
					Summary: []ResponsesSummary{{
						Type: "summary_text",
						Text: block.Thinking,
					}},
				})
			}
		case "text":
			if block.Text != "" {
				msgParts = append(msgParts, ResponsesContentPart{
					Type: "output_text",
					Text: block.Text,
				})
			}
		case "tool_use":
			// Empty tool input is normalized to an empty JSON object.
			args := "{}"
			if len(block.Input) > 0 {
				args = string(block.Input)
			}
			outputs = append(outputs, ResponsesOutput{
				Type:      "function_call",
				ID:        generateItemID(),
				CallID:    toResponsesCallID(block.ID),
				Name:      block.Name,
				Arguments: args,
				Status:    "completed",
			})
		}
	}

	// Assemble message output item from text parts
	if len(msgParts) > 0 {
		outputs = append(outputs, ResponsesOutput{
			Type:    "message",
			ID:      generateItemID(),
			Role:    "assistant",
			Content: msgParts,
			Status:  "completed",
		})
	}

	// Guarantee at least one output item: an empty assistant message.
	if len(outputs) == 0 {
		outputs = append(outputs, ResponsesOutput{
			Type:    "message",
			ID:      generateItemID(),
			Role:    "assistant",
			Content: []ResponsesContentPart{{Type: "output_text", Text: ""}},
			Status:  "completed",
		})
	}
	out.Output = outputs

	// Map stop_reason → status
	out.Status = anthropicStopReasonToResponsesStatus(resp.StopReason, resp.Content)
	if out.Status == "incomplete" {
		// Only "max_tokens" maps to "incomplete" (see mapping below), so the
		// reason is always max_output_tokens here.
		out.IncompleteDetails = &ResponsesIncompleteDetails{Reason: "max_output_tokens"}
	}

	// Usage
	out.Usage = &ResponsesUsage{
		InputTokens:  resp.Usage.InputTokens,
		OutputTokens: resp.Usage.OutputTokens,
		TotalTokens:  resp.Usage.InputTokens + resp.Usage.OutputTokens,
	}
	if resp.Usage.CacheReadInputTokens > 0 {
		out.Usage.InputTokensDetails = &ResponsesInputTokensDetails{
			CachedTokens: resp.Usage.CacheReadInputTokens,
		}
	}

	return out
}

// anthropicStopReasonToResponsesStatus maps Anthropic stop_reason to Responses status.
// Only "max_tokens" yields "incomplete"; every other (including unknown/empty)
// stop reason is treated as normal completion. The blocks parameter is
// currently unused by the body; kept for future refinement.
func anthropicStopReasonToResponsesStatus(stopReason string, blocks []AnthropicContentBlock) string {
	switch stopReason {
	case "max_tokens":
		return "incomplete"
	case "end_turn", "tool_use", "stop_sequence":
		return "completed"
	default:
		return "completed"
	}
}

// ---------------------------------------------------------------------------
// Streaming: AnthropicStreamEvent → []ResponsesStreamEvent (stateful converter)
// ---------------------------------------------------------------------------

// AnthropicEventToResponsesState tracks state for converting a sequence of
// Anthropic SSE events into Responses SSE events.
type AnthropicEventToResponsesState struct {
	ResponseID     string
	Model          string
	Created        int64
	SequenceNumber int

	// CreatedSent tracks whether response.created has been emitted.
	CreatedSent bool
	// CompletedSent tracks whether the terminal event has been emitted.
	CompletedSent bool

	// Current output tracking
	OutputIndex     int
	CurrentItemID   string
	CurrentItemType string // "message" | "function_call" | "reasoning"

	// For message output: accumulate text parts
	ContentIndex int

	// For function_call: track per-output info
	CurrentCallID string
	CurrentName   string

	// Usage from message_delta
	InputTokens          int
	OutputTokens         int
	CacheReadInputTokens int
}

// NewAnthropicEventToResponsesState returns an initialised stream state.
func NewAnthropicEventToResponsesState() *AnthropicEventToResponsesState {
	return &AnthropicEventToResponsesState{
		Created: time.Now().Unix(),
	}
}

// AnthropicEventToResponsesEvents converts a single Anthropic SSE event into
// zero or more Responses SSE events, updating state as it goes.
+func AnthropicEventToResponsesEvents( + evt *AnthropicStreamEvent, + state *AnthropicEventToResponsesState, +) []ResponsesStreamEvent { + switch evt.Type { + case "message_start": + return anthToResHandleMessageStart(evt, state) + case "content_block_start": + return anthToResHandleContentBlockStart(evt, state) + case "content_block_delta": + return anthToResHandleContentBlockDelta(evt, state) + case "content_block_stop": + return anthToResHandleContentBlockStop(evt, state) + case "message_delta": + return anthToResHandleMessageDelta(evt, state) + case "message_stop": + return anthToResHandleMessageStop(state) + default: + return nil + } +} + +// FinalizeAnthropicResponsesStream emits synthetic termination events if the +// stream ended without a proper message_stop. +func FinalizeAnthropicResponsesStream(state *AnthropicEventToResponsesState) []ResponsesStreamEvent { + if !state.CreatedSent || state.CompletedSent { + return nil + } + + var events []ResponsesStreamEvent + + // Close any open item + events = append(events, closeCurrentResponsesItem(state)...) + + // Emit response.completed + events = append(events, makeResponsesCompletedEvent(state, "completed", nil)) + state.CompletedSent = true + return events +} + +// ResponsesEventToSSE formats a ResponsesStreamEvent as an SSE data line. 
func ResponsesEventToSSE(evt ResponsesStreamEvent) (string, error) {
	data, err := json.Marshal(evt)
	if err != nil {
		return "", err
	}
	// SSE frame: "event:" name line, "data:" payload line, blank-line terminator.
	return fmt.Sprintf("event: %s\ndata: %s\n\n", evt.Type, data), nil
}

// --- internal handlers ---

// anthToResHandleMessageStart records the upstream message ID/model and input
// token usage, then emits response.created exactly once per stream.
func anthToResHandleMessageStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if evt.Message != nil {
		state.ResponseID = evt.Message.ID
		if state.Model == "" {
			state.Model = evt.Message.Model
		}
		if evt.Message.Usage.InputTokens > 0 {
			state.InputTokens = evt.Message.Usage.InputTokens
		}
	}

	if state.CreatedSent {
		return nil
	}
	state.CreatedSent = true

	// Emit response.created
	return []ResponsesStreamEvent{makeResponsesCreatedEvent(state)}
}

// anthToResHandleContentBlockStart opens the Responses output item matching a
// newly started Anthropic content block.
func anthToResHandleContentBlockStart(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if evt.ContentBlock == nil {
		return nil
	}

	var events []ResponsesStreamEvent

	switch evt.ContentBlock.Type {
	case "thinking":
		// NOTE(review): unlike "tool_use" below, this does not close a
		// still-open item first; it relies on the previous block having been
		// closed by its content_block_stop — confirm against upstream event
		// ordering (a thinking block after an open message item would orphan
		// that message item).
		state.CurrentItemID = generateItemID()
		state.CurrentItemType = "reasoning"
		state.ContentIndex = 0

		events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Item: &ResponsesOutput{
				Type: "reasoning",
				ID:   state.CurrentItemID,
			},
		}))

	case "text":
		// If we don't have an open message item, open one
		if state.CurrentItemType != "message" {
			state.CurrentItemID = generateItemID()
			state.CurrentItemType = "message"
			state.ContentIndex = 0

			events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
				OutputIndex: state.OutputIndex,
				Item: &ResponsesOutput{
					Type:   "message",
					ID:     state.CurrentItemID,
					Role:   "assistant",
					Status: "in_progress",
				},
			}))
		}

	case "tool_use":
		// Close previous item if any
		events = append(events, closeCurrentResponsesItem(state)...)

		state.CurrentItemID = generateItemID()
		state.CurrentItemType = "function_call"
		state.CurrentCallID = toResponsesCallID(evt.ContentBlock.ID)
		state.CurrentName = evt.ContentBlock.Name

		events = append(events, makeResponsesEvent(state, "response.output_item.added", &ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Item: &ResponsesOutput{
				Type:   "function_call",
				ID:     state.CurrentItemID,
				CallID: state.CurrentCallID,
				Name:   state.CurrentName,
				Status: "in_progress",
			},
		}))
	}

	return events
}

// anthToResHandleContentBlockDelta forwards an incremental text / thinking /
// tool-argument delta as the corresponding Responses delta event. Empty
// deltas are dropped.
func anthToResHandleContentBlockDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if evt.Delta == nil {
		return nil
	}

	switch evt.Delta.Type {
	case "text_delta":
		if evt.Delta.Text == "" {
			return nil
		}
		return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_text.delta", &ResponsesStreamEvent{
			OutputIndex:  state.OutputIndex,
			ContentIndex: state.ContentIndex,
			Delta:        evt.Delta.Text,
			ItemID:       state.CurrentItemID,
		})}

	case "thinking_delta":
		if evt.Delta.Thinking == "" {
			return nil
		}
		return []ResponsesStreamEvent{makeResponsesEvent(state, "response.reasoning_summary_text.delta", &ResponsesStreamEvent{
			OutputIndex:  state.OutputIndex,
			SummaryIndex: 0,
			Delta:        evt.Delta.Thinking,
			ItemID:       state.CurrentItemID,
		})}

	case "input_json_delta":
		if evt.Delta.PartialJSON == "" {
			return nil
		}
		return []ResponsesStreamEvent{makeResponsesEvent(state, "response.function_call_arguments.delta", &ResponsesStreamEvent{
			OutputIndex: state.OutputIndex,
			Delta:       evt.Delta.PartialJSON,
			ItemID:      state.CurrentItemID,
			CallID:      state.CurrentCallID,
			Name:        state.CurrentName,
		})}

	case "signature_delta":
		// Anthropic signature deltas have no Responses equivalent; skip
		return nil
	}

	return nil
}

// anthToResHandleContentBlockStop emits the "done" event(s) for the item type
// currently open. The decision is driven purely by converter state; evt is
// not inspected.
func anthToResHandleContentBlockStop(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	switch state.CurrentItemType {
	case "reasoning":
		// Emit reasoning summary done + output item done
		events := []ResponsesStreamEvent{
			makeResponsesEvent(state, "response.reasoning_summary_text.done", &ResponsesStreamEvent{
				OutputIndex:  state.OutputIndex,
				SummaryIndex: 0,
				ItemID:       state.CurrentItemID,
			}),
		}
		events = append(events, closeCurrentResponsesItem(state)...)
		return events

	case "function_call":
		// Emit function_call_arguments.done + output item done
		events := []ResponsesStreamEvent{
			makeResponsesEvent(state, "response.function_call_arguments.done", &ResponsesStreamEvent{
				OutputIndex: state.OutputIndex,
				ItemID:      state.CurrentItemID,
				CallID:      state.CurrentCallID,
				Name:        state.CurrentName,
			}),
		}
		events = append(events, closeCurrentResponsesItem(state)...)
		return events

	case "message":
		// Emit output_text.done (text block is done, but message item stays open for potential more blocks)
		// NOTE(review): ContentIndex is not advanced here, so a second text
		// block in the same message would reuse content index 0 — confirm.
		return []ResponsesStreamEvent{
			makeResponsesEvent(state, "response.output_text.done", &ResponsesStreamEvent{
				OutputIndex:  state.OutputIndex,
				ContentIndex: state.ContentIndex,
				ItemID:       state.CurrentItemID,
			}),
		}
	}

	return nil
}

// anthToResHandleMessageDelta captures usage counters from message_delta.
// NOTE(review): the delta's stop_reason is not inspected, so a max_tokens
// truncation is never surfaced as status "incomplete" on the streaming path
// (the non-streaming converter does map it) — confirm this is intended.
func anthToResHandleMessageDelta(evt *AnthropicStreamEvent, state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	// Update usage
	if evt.Usage != nil {
		state.OutputTokens = evt.Usage.OutputTokens
		if evt.Usage.CacheReadInputTokens > 0 {
			state.CacheReadInputTokens = evt.Usage.CacheReadInputTokens
		}
	}

	return nil
}

// anthToResHandleMessageStop closes any open item and emits the terminal
// response.completed event (status is always "completed" here).
func anthToResHandleMessageStop(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if state.CompletedSent {
		return nil
	}

	var events []ResponsesStreamEvent

	// Close any open item
	events = append(events, closeCurrentResponsesItem(state)...)

	// Determine status
	status := "completed"
	var incompleteDetails *ResponsesIncompleteDetails

	// Emit response.completed
	events = append(events, makeResponsesCompletedEvent(state, status, incompleteDetails))
	state.CompletedSent = true
	return events
}

// --- helper functions ---

// closeCurrentResponsesItem emits response.output_item.done for the open item
// (if any), resets per-item state, and advances OutputIndex.
func closeCurrentResponsesItem(state *AnthropicEventToResponsesState) []ResponsesStreamEvent {
	if state.CurrentItemType == "" {
		return nil
	}

	itemType := state.CurrentItemType
	itemID := state.CurrentItemID

	// Reset
	state.CurrentItemType = ""
	state.CurrentItemID = ""
	state.CurrentCallID = ""
	state.CurrentName = ""
	state.OutputIndex++
	state.ContentIndex = 0

	// The emitted item carries only type/ID/status; call_id, name and content
	// accumulated earlier are not repeated here.
	return []ResponsesStreamEvent{makeResponsesEvent(state, "response.output_item.done", &ResponsesStreamEvent{
		OutputIndex: state.OutputIndex - 1, // Use the index before increment
		Item: &ResponsesOutput{
			Type:   itemType,
			ID:     itemID,
			Status: "completed",
		},
	})}
}

// makeResponsesCreatedEvent builds the response.created event with the next
// sequence number and an in-progress response shell.
func makeResponsesCreatedEvent(state *AnthropicEventToResponsesState) ResponsesStreamEvent {
	seq := state.SequenceNumber
	state.SequenceNumber++
	return ResponsesStreamEvent{
		Type:           "response.created",
		SequenceNumber: seq,
		Response: &ResponsesResponse{
			ID:     state.ResponseID,
			Object: "response",
			Model:  state.Model,
			Status: "in_progress",
			Output: []ResponsesOutput{},
		},
	}
}

// makeResponsesCompletedEvent builds the terminal response.completed event,
// folding accumulated token usage into the response payload.
func makeResponsesCompletedEvent(
	state *AnthropicEventToResponsesState,
	status string,
	incompleteDetails *ResponsesIncompleteDetails,
) ResponsesStreamEvent {
	seq := state.SequenceNumber
	state.SequenceNumber++

	usage := &ResponsesUsage{
		InputTokens:  state.InputTokens,
		OutputTokens: state.OutputTokens,
		TotalTokens:  state.InputTokens + state.OutputTokens,
	}
	if state.CacheReadInputTokens > 0 {
		usage.InputTokensDetails = &ResponsesInputTokensDetails{
			CachedTokens: state.CacheReadInputTokens,
		}
	}

	return ResponsesStreamEvent{
		Type:           "response.completed",
		SequenceNumber: seq,
		Response: &ResponsesResponse{
			ID:                state.ResponseID,
			Object:            "response",
			Model:             state.Model,
			Status:            status,
			Output:            []ResponsesOutput{}, // Simplified; full output tracking would add complexity
			Usage:             usage,
			IncompleteDetails: incompleteDetails,
		},
	}
}

// makeResponsesEvent stamps the template with the event type and the next
// sequence number, returning a copy (the template is not mutated).
func makeResponsesEvent(state *AnthropicEventToResponsesState, eventType string, template *ResponsesStreamEvent) ResponsesStreamEvent {
	seq := state.SequenceNumber
	state.SequenceNumber++

	evt := *template
	evt.Type = eventType
	evt.SequenceNumber = seq
	return evt
}

// generateResponsesID returns a random "resp_"-prefixed hex identifier.
func generateResponsesID() string {
	b := make([]byte, 12)
	_, _ = rand.Read(b)
	return "resp_" + hex.EncodeToString(b)
}

// generateItemID returns a random "item_"-prefixed hex identifier.
func generateItemID() string {
	b := make([]byte, 12)
	_, _ = rand.Read(b)
	return "item_" + hex.EncodeToString(b)
}
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
new file mode 100644
index 00000000..f0a5b07e
--- /dev/null
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
@@ -0,0 +1,464 @@
package apicompat

import (
	"encoding/json"
	"fmt"
	"strings"
)

// ResponsesToAnthropicRequest converts a Responses API request into an
// Anthropic Messages request. This is the reverse of AnthropicToResponses and
// enables Anthropic platform groups to accept OpenAI Responses API requests
// by converting them to the native /v1/messages format before forwarding upstream.
+func ResponsesToAnthropicRequest(req *ResponsesRequest) (*AnthropicRequest, error) { + system, messages, err := convertResponsesInputToAnthropic(req.Input) + if err != nil { + return nil, err + } + + out := &AnthropicRequest{ + Model: req.Model, + Messages: messages, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: req.Stream, + } + + if len(system) > 0 { + out.System = system + } + + // max_output_tokens → max_tokens + if req.MaxOutputTokens != nil && *req.MaxOutputTokens > 0 { + out.MaxTokens = *req.MaxOutputTokens + } + if out.MaxTokens == 0 { + // Anthropic requires max_tokens; default to a sensible value. + out.MaxTokens = 8192 + } + + // Convert tools + if len(req.Tools) > 0 { + out.Tools = convertResponsesToAnthropicTools(req.Tools) + } + + // Convert tool_choice (reverse of convertAnthropicToolChoiceToResponses) + if len(req.ToolChoice) > 0 { + tc, err := convertResponsesToAnthropicToolChoice(req.ToolChoice) + if err != nil { + return nil, fmt.Errorf("convert tool_choice: %w", err) + } + out.ToolChoice = tc + } + + // reasoning.effort → output_config.effort + thinking + if req.Reasoning != nil && req.Reasoning.Effort != "" { + effort := mapResponsesEffortToAnthropic(req.Reasoning.Effort) + out.OutputConfig = &AnthropicOutputConfig{Effort: effort} + // Enable thinking for non-low efforts + if effort != "low" { + out.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: defaultThinkingBudget(effort), + } + } + } + + return out, nil +} + +// defaultThinkingBudget returns a sensible thinking budget based on effort level. +func defaultThinkingBudget(effort string) int { + switch effort { + case "low": + return 1024 + case "medium": + return 4096 + case "high": + return 10240 + case "max": + return 32768 + default: + return 10240 + } +} + +// mapResponsesEffortToAnthropic converts OpenAI Responses reasoning effort to +// Anthropic effort levels. Reverse of mapAnthropicEffortToResponses. 
//
//	low    → low
//	medium → medium
//	high   → high
//	xhigh  → max
func mapResponsesEffortToAnthropic(effort string) string {
	if effort == "xhigh" {
		return "max"
	}
	return effort // low→low, medium→medium, high→high, unknown→passthrough
}

// convertResponsesInputToAnthropic extracts system prompt and messages from
// a Responses API input array. Returns the system as raw JSON (for Anthropic's
// polymorphic system field) and a list of Anthropic messages.
func convertResponsesInputToAnthropic(inputRaw json.RawMessage) (json.RawMessage, []AnthropicMessage, error) {
	// Try as plain string input.
	var inputStr string
	if err := json.Unmarshal(inputRaw, &inputStr); err == nil {
		content, _ := json.Marshal(inputStr)
		return nil, []AnthropicMessage{{Role: "user", Content: content}}, nil
	}

	var items []ResponsesInputItem
	if err := json.Unmarshal(inputRaw, &items); err != nil {
		return nil, nil, fmt.Errorf("parse responses input: %w", err)
	}

	var system json.RawMessage
	var messages []AnthropicMessage

	// NOTE: case order matters — type checks (function_call/…_output) must win
	// over role checks, since those items may also carry a role.
	for _, item := range items {
		switch {
		case item.Role == "system":
			// System prompt → Anthropic system field.
			// NOTE(review): with multiple system items only the LAST non-empty
			// one survives (each iteration overwrites `system`) — confirm this
			// is intended rather than concatenation.
			text := extractTextFromContent(item.Content)
			if text != "" {
				system, _ = json.Marshal(text)
			}

		case item.Type == "function_call":
			// function_call → assistant message with tool_use block
			// Arguments are embedded verbatim and assumed to be valid JSON.
			input := json.RawMessage("{}")
			if item.Arguments != "" {
				input = json.RawMessage(item.Arguments)
			}
			block := AnthropicContentBlock{
				Type:  "tool_use",
				ID:    fromResponsesCallIDToAnthropic(item.CallID),
				Name:  item.Name,
				Input: input,
			}
			blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
			messages = append(messages, AnthropicMessage{
				Role:    "assistant",
				Content: blockJSON,
			})

		case item.Type == "function_call_output":
			// function_call_output → user message with tool_result block.
			// Substitute a placeholder so the tool_result content is never empty.
			outputContent := item.Output
			if outputContent == "" {
				outputContent = "(empty)"
			}
			contentJSON, _ := json.Marshal(outputContent)
			block := AnthropicContentBlock{
				Type:      "tool_result",
				ToolUseID: fromResponsesCallIDToAnthropic(item.CallID),
				Content:   contentJSON,
			}
			blockJSON, _ := json.Marshal([]AnthropicContentBlock{block})
			messages = append(messages, AnthropicMessage{
				Role:    "user",
				Content: blockJSON,
			})

		case item.Role == "user":
			content, err := convertResponsesUserToAnthropicContent(item.Content)
			if err != nil {
				return nil, nil, err
			}
			messages = append(messages, AnthropicMessage{
				Role:    "user",
				Content: content,
			})

		case item.Role == "assistant":
			content, err := convertResponsesAssistantToAnthropicContent(item.Content)
			if err != nil {
				return nil, nil, err
			}
			messages = append(messages, AnthropicMessage{
				Role:    "assistant",
				Content: content,
			})

		default:
			// Unknown role/type — attempt as user message
			if item.Content != nil {
				messages = append(messages, AnthropicMessage{
					Role:    "user",
					Content: item.Content,
				})
			}
		}
	}

	// Merge consecutive same-role messages (Anthropic requires alternating roles)
	messages = mergeConsecutiveMessages(messages)

	return system, messages, nil
}

// extractTextFromContent extracts text from a content field that may be a
// plain string or an array of content parts. Text parts are joined with a
// blank line; anything unparseable yields "".
func extractTextFromContent(raw json.RawMessage) string {
	if len(raw) == 0 {
		return ""
	}
	var s string
	if err := json.Unmarshal(raw, &s); err == nil {
		return s
	}
	var parts []ResponsesContentPart
	if err := json.Unmarshal(raw, &parts); err == nil {
		var texts []string
		for _, p := range parts {
			if (p.Type == "input_text" || p.Type == "output_text" || p.Type == "text") && p.Text != "" {
				texts = append(texts, p.Text)
			}
		}
		return strings.Join(texts, "\n\n")
	}
	return ""
}

// convertResponsesUserToAnthropicContent converts a Responses user message
// content field into Anthropic content blocks JSON.
+func convertResponsesUserToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal("") // empty string content + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Array of content parts → Anthropic content blocks. + var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + // Pass through as-is if we can't parse + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "input_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + case "input_image": + src := dataURIToAnthropicImageSource(p.ImageURL) + if src != nil { + blocks = append(blocks, AnthropicContentBlock{ + Type: "image", + Source: src, + }) + } + } + } + + if len(blocks) == 0 { + return json.Marshal("") + } + return json.Marshal(blocks) +} + +// convertResponsesAssistantToAnthropicContent converts a Responses assistant +// message content field into Anthropic content blocks JSON. +func convertResponsesAssistantToAnthropicContent(raw json.RawMessage) (json.RawMessage, error) { + if len(raw) == 0 { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: ""}}) + } + + // Try plain string. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal([]AnthropicContentBlock{{Type: "text", Text: s}}) + } + + // Array of content parts → Anthropic content blocks. 
+ var parts []ResponsesContentPart + if err := json.Unmarshal(raw, &parts); err != nil { + return raw, nil + } + + var blocks []AnthropicContentBlock + for _, p := range parts { + switch p.Type { + case "output_text", "text": + if p.Text != "" { + blocks = append(blocks, AnthropicContentBlock{ + Type: "text", + Text: p.Text, + }) + } + } + } + + if len(blocks) == 0 { + blocks = append(blocks, AnthropicContentBlock{Type: "text", Text: ""}) + } + return json.Marshal(blocks) +} + +// fromResponsesCallIDToAnthropic converts an OpenAI function call ID back to +// Anthropic format. Reverses toResponsesCallID. +func fromResponsesCallIDToAnthropic(id string) string { + // If it has our "fc_" prefix wrapping a known Anthropic prefix, strip it + if after, ok := strings.CutPrefix(id, "fc_"); ok { + if strings.HasPrefix(after, "toolu_") || strings.HasPrefix(after, "call_") { + return after + } + } + // Generate a synthetic Anthropic tool ID + if !strings.HasPrefix(id, "toolu_") && !strings.HasPrefix(id, "call_") { + return "toolu_" + id + } + return id +} + +// dataURIToAnthropicImageSource parses a data URI into an AnthropicImageSource. +func dataURIToAnthropicImageSource(dataURI string) *AnthropicImageSource { + if !strings.HasPrefix(dataURI, "data:") { + return nil + } + // Format: data:<media-type>;base64,<data> + rest := strings.TrimPrefix(dataURI, "data:") + semicolonIdx := strings.Index(rest, ";") + if semicolonIdx < 0 { + return nil + } + mediaType := rest[:semicolonIdx] + rest = rest[semicolonIdx+1:] + if !strings.HasPrefix(rest, "base64,") { + return nil + } + data := strings.TrimPrefix(rest, "base64,") + return &AnthropicImageSource{ + Type: "base64", + MediaType: mediaType, + Data: data, + } +} + +// mergeConsecutiveMessages merges consecutive messages with the same role +// because Anthropic requires alternating user/assistant turns. 
+func mergeConsecutiveMessages(messages []AnthropicMessage) []AnthropicMessage { + if len(messages) <= 1 { + return messages + } + + var merged []AnthropicMessage + for _, msg := range messages { + if len(merged) == 0 || merged[len(merged)-1].Role != msg.Role { + merged = append(merged, msg) + continue + } + + // Same role — merge content arrays + last := &merged[len(merged)-1] + lastBlocks := parseContentBlocks(last.Content) + newBlocks := parseContentBlocks(msg.Content) + combined := append(lastBlocks, newBlocks...) + last.Content, _ = json.Marshal(combined) + } + return merged +} + +// parseContentBlocks attempts to parse content as []AnthropicContentBlock. +// If it's a string, wraps it in a text block. +func parseContentBlocks(raw json.RawMessage) []AnthropicContentBlock { + var blocks []AnthropicContentBlock + if err := json.Unmarshal(raw, &blocks); err == nil { + return blocks + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return []AnthropicContentBlock{{Type: "text", Text: s}} + } + return nil +} + +// convertResponsesToAnthropicTools maps Responses API tools to Anthropic format. +// Reverse of convertAnthropicToolsToResponses. +func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool { + var out []AnthropicTool + for _, t := range tools { + switch t.Type { + case "web_search": + out = append(out, AnthropicTool{ + Type: "web_search_20250305", + Name: "web_search", + }) + case "function": + out = append(out, AnthropicTool{ + Name: t.Name, + Description: t.Description, + InputSchema: normalizeAnthropicInputSchema(t.Parameters), + }) + default: + // Pass through unknown tool types + out = append(out, AnthropicTool{ + Type: t.Type, + Name: t.Name, + Description: t.Description, + InputSchema: t.Parameters, + }) + } + } + return out +} + +// normalizeAnthropicInputSchema ensures the input_schema has a "type" field. 
+func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage { + if len(schema) == 0 || string(schema) == "null" { + return json.RawMessage(`{"type":"object","properties":{}}`) + } + return schema +} + +// convertResponsesToAnthropicToolChoice maps Responses tool_choice to Anthropic format. +// Reverse of convertAnthropicToolChoiceToResponses. +// +// "auto" → {"type":"auto"} +// "required" → {"type":"any"} +// "none" → {"type":"none"} +// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} +func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try as string first + var s string + if err := json.Unmarshal(raw, &s); err == nil { + switch s { + case "auto": + return json.Marshal(map[string]string{"type": "auto"}) + case "required": + return json.Marshal(map[string]string{"type": "any"}) + case "none": + return json.Marshal(map[string]string{"type": "none"}) + default: + return raw, nil + } + } + + // Try as object with type=function + var tc struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + } + if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" { + return json.Marshal(map[string]string{ + "type": "tool", + "name": tc.Function.Name, + }) + } + + // Pass through unknown + return raw, nil +} diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index a35a5ea6..6b8521bd 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -270,6 +270,7 @@ type OpenAIAuthClaims struct { ChatGPTUserID string `json:"chatgpt_user_id"` ChatGPTPlanType string `json:"chatgpt_plan_type"` UserID string `json:"user_id"` + POID string `json:"poid"` // organization ID in access_token JWT Organizations []OrganizationClaim `json:"organizations"` } diff --git a/backend/internal/repository/account_repo.go 
b/backend/internal/repository/account_repo.go index 35b908de..d45e8a12 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -404,6 +404,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account return nil } +func (r *accountRepository) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + _, err := r.client.Account.UpdateOneID(id). + SetCredentials(normalizeJSONMap(credentials)). + Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + r.syncSchedulerAccountSnapshot(ctx, id) + return nil +} + func (r *accountRepository) Delete(ctx context.Context, id int64) error { groupIDs, err := r.loadAccountGroupIDs(ctx, id) if err != nil { @@ -443,10 +454,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "", 0) + return r.ListWithFilters(ctx, params, "", "", "", "", 0, "") } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -479,6 +490,20 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati } else if groupID > 0 { q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) } + if privacyMode != "" { + q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { + path := 
sqljson.Path("privacy_mode") + switch privacyMode { + case service.AccountPrivacyModeUnsetFilter: + s.Where(entsql.Or( + entsql.Not(sqljson.HasKey(dbaccount.FieldExtra, path)), + sqljson.ValueEQ(dbaccount.FieldExtra, "", path), + )) + default: + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, privacyMode, path)) + } + })) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index d6f0e337..8da30c92 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -208,15 +208,16 @@ func (s *AccountRepoSuite) TestList() { func (s *AccountRepoSuite) TestListWithFilters() { tests := []struct { - name string - setup func(client *dbent.Client) - platform string - accType string - status string - search string - groupID int64 - wantCount int - validate func(accounts []service.Account) + name string + setup func(client *dbent.Client) + platform string + accType string + status string + search string + groupID int64 + privacyMode string + wantCount int + validate func(accounts []service.Account) }{ { name: "filter_by_platform", @@ -281,6 +282,32 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Empty(accounts[0].GroupIDs) }, }, + { + name: "filter_by_privacy_mode", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-ok", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-fail", Extra: map[string]any{"privacy_mode": service.PrivacyModeFailed}}) + }, + privacyMode: service.PrivacyModeTrainingOff, + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("privacy-ok", accounts[0].Name) + }, + }, + { + name: "filter_by_privacy_mode_unset", + setup: func(client *dbent.Client) { + 
mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-unset", Extra: nil}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-empty", Extra: map[string]any{"privacy_mode": ""}}) + mustCreateAccount(s.T(), client, &service.Account{Name: "privacy-set", Extra: map[string]any{"privacy_mode": service.PrivacyModeTrainingOff}}) + }, + privacyMode: service.AccountPrivacyModeUnsetFilter, + wantCount: 2, + validate: func(accounts []service.Account) { + names := []string{accounts[0].Name, accounts[1].Name} + s.ElementsMatch([]string{"privacy-unset", "privacy-empty"}, names) + }, + }, } for _, tt := range tests { @@ -293,7 +320,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, tt.groupID, tt.privacyMode) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -360,7 +387,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0, "") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 02ca1a3b..5154b269 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -29,6 +29,11 @@ INSERT INTO ops_error_logs ( model, 
request_path, stream, + inbound_endpoint, + upstream_endpoint, + requested_model, + upstream_model, + request_type, user_agent, error_phase, error_type, @@ -57,7 +62,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43 )` func NewOpsRepository(db *sql.DB) service.OpsRepository { @@ -140,6 +145,11 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { opsNullString(input.Model), opsNullString(input.RequestPath), input.Stream, + opsNullString(input.InboundEndpoint), + opsNullString(input.UpstreamEndpoint), + opsNullString(input.RequestedModel), + opsNullString(input.UpstreamModel), + opsNullInt16(input.RequestType), opsNullString(input.UserAgent), input.ErrorPhase, input.ErrorType, @@ -231,7 +241,12 @@ SELECT COALESCE(g.name, ''), CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), - e.stream + e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type FROM ops_error_logs e LEFT JOIN accounts a ON e.account_id = a.id LEFT JOIN groups g ON e.group_id = g.id @@ -263,6 +278,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) var resolvedBy sql.NullInt64 var resolvedByName string var resolvedRetryID sql.NullInt64 + var requestType sql.NullInt64 if err := rows.Scan( &item.ID, &item.CreatedAt, @@ -294,6 +310,11 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) &clientIP, &item.RequestPath, &item.Stream, + &item.InboundEndpoint, + &item.UpstreamEndpoint, + &item.RequestedModel, + &item.UpstreamModel, + &requestType, ); err != nil { return nil, err 
} @@ -334,6 +355,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) item.GroupID = &v } item.GroupName = groupName + if requestType.Valid { + v := int16(requestType.Int64) + item.RequestType = &v + } out = append(out, &item) } if err := rows.Err(); err != nil { @@ -393,6 +418,11 @@ SELECT CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END, COALESCE(e.request_path, ''), e.stream, + COALESCE(e.inbound_endpoint, ''), + COALESCE(e.upstream_endpoint, ''), + COALESCE(e.requested_model, ''), + COALESCE(e.upstream_model, ''), + e.request_type, COALESCE(e.user_agent, ''), e.auth_latency_ms, e.routing_latency_ms, @@ -427,6 +457,7 @@ LIMIT 1` var responseLatency sql.NullInt64 var ttft sql.NullInt64 var requestBodyBytes sql.NullInt64 + var requestType sql.NullInt64 err := r.db.QueryRowContext(ctx, q, id).Scan( &out.ID, @@ -464,6 +495,11 @@ LIMIT 1` &clientIP, &out.RequestPath, &out.Stream, + &out.InboundEndpoint, + &out.UpstreamEndpoint, + &out.RequestedModel, + &out.UpstreamModel, + &requestType, &out.UserAgent, &authLatency, &routingLatency, @@ -540,6 +576,10 @@ LIMIT 1` v := int(requestBodyBytes.Int64) out.RequestBodyBytes = &v } + if requestType.Valid { + v := int16(requestType.Int64) + out.RequestType = &v + } // Normalize request_body to empty string when stored as JSON null. 
out.RequestBody = strings.TrimSpace(out.RequestBody) @@ -1479,3 +1519,10 @@ func opsNullInt(v any) any { return sql.NullInt64{} } } + +func opsNullInt16(v *int16) any { + if v == nil { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: int64(*v), Valid: true} +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a6bd50ac..880fe8a0 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -540,7 +540,8 @@ func TestAPIContracts(t *testing.T) { "max_claude_code_version": "", "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, - "custom_menu_items": [] + "custom_menu_items": [], + "custom_endpoints": [] } }`, }, @@ -989,7 +990,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index fe820830..072cfdee 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -69,12 +69,30 @@ func RegisterGatewayRoutes( }) gateway.GET("/models", h.Gateway.Models) gateway.GET("/usage", h.Gateway.Usage) - // OpenAI Responses API - gateway.POST("/responses", h.OpenAIGateway.Responses) - gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) + // OpenAI Responses API: auto-route based on group platform + 
gateway.POST("/responses", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) + gateway.POST("/responses/*subpath", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + }) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API - gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) + // OpenAI Chat Completions API: auto-route based on group platform + gateway.POST("/chat/completions", func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -92,12 +110,25 @@ func RegisterGatewayRoutes( gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } - // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) - r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + // OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform + responsesHandler := func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.Responses(c) + return + } + h.Gateway.Responses(c) + } + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), 
requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) - // OpenAI Chat Completions API(不带v1前缀的别名) - r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) + // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { + if getGroupPlatform(c) == service.PlatformOpenAI { + h.OpenAIGateway.ChatCompletions(c) + return + } + h.Gateway.ChatCompletions(c) + }) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/service/account_credentials_persistence.go b/backend/internal/service/account_credentials_persistence.go new file mode 100644 index 00000000..916df536 --- /dev/null +++ b/backend/internal/service/account_credentials_persistence.go @@ -0,0 +1,30 @@ +package service + +import "context" + +type accountCredentialsUpdater interface { + UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error +} + +func persistAccountCredentials(ctx context.Context, repo AccountRepository, account *Account, credentials map[string]any) error { + if repo == nil || account == nil { + return nil + } + + account.Credentials = cloneCredentials(credentials) + if updater, ok := any(repo).(accountCredentialsUpdater); ok { + return updater.UpdateCredentials(ctx, account.ID, account.Credentials) + } + return repo.Update(ctx, account) +} + +func cloneCredentials(in map[string]any) map[string]any { + if in == nil { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 2e91db6b..71d51712 100644 
--- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -15,6 +15,7 @@ var ( ) const AccountListGroupUngrouped int64 = -1 +const AccountPrivacyModeUnsetFilter = "__unset__" type AccountRepository interface { Create(ctx context.Context, account *Account) error @@ -37,7 +38,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index c96b436f..81169a02 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } 
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index ccd681a3..ed85ee34 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -54,7 +54,7 @@ type AdminService interface { ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -1451,9 +1451,9 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) if err != nil { return nil, 0, err } diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index ff58fd01..eb213e6a 100644 --- a/backend/internal/service/admin_service_search_test.go 
+++ b/backend/internal/service/admin_service_search_test.go @@ -19,18 +19,20 @@ type accountRepoStubForAdminList struct { listWithFiltersType string listWithFiltersStatus string listWithFiltersSearch string + listWithFiltersPrivacy string listWithFiltersAccounts []Account listWithFiltersResult *pagination.PaginationResult listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform s.listWithFiltersType = accountType s.listWithFiltersStatus = status s.listWithFiltersSearch = search + s.listWithFiltersPrivacy = privacyMode if s.listWithFiltersErr != nil { return nil, nil, s.listWithFiltersErr @@ -168,7 +170,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) @@ -182,6 +184,22 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { }) } +func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { + t.Run("privacy_mode 参数正常传递到 repository 层", func(t *testing.T) { + repo := &accountRepoStubForAdminList{ + listWithFiltersAccounts: []Account{{ID: 2, Name: "acc2"}}, + listWithFiltersResult: 
&pagination.PaginationResult{Total: 1}, + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) + require.Equal(t, PrivacyModeCFBlocked, repo.listWithFiltersPrivacy) + }) +} + func TestAdminService_ListProxies_WithSearch(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { repo := &proxyRepoStubForAdminList{ diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 6ee8280c..aa5d948c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -643,6 +643,7 @@ urlFallbackLoop: AccountID: p.account.ID, AccountName: p.account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -720,6 +721,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), @@ -754,6 +756,7 @@ urlFallbackLoop: AccountName: p.account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: upstreamMsg, Detail: getUpstreamDetail(respBody), diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 5e53f434..1b360d93 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -138,7 +138,7 @@ func (p 
*AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * p.markBackfillAttempted(account.ID) if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { account.Credentials["project_id"] = projectID - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + if updateErr := persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials); updateErr != nil { slog.Warn("antigravity_project_id_backfill_persist_failed", "account_id", account.ID, "error", updateErr, diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 6a916740..b69b0639 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -367,8 +367,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after creation if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } } item.Action = "created" @@ -402,8 +401,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } } @@ -620,8 +618,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } // 🔄 Refresh OAuth token after creation if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = 
persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -652,8 +649,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput // 🔄 Refresh OAuth token after update if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } item.Action = "updated" @@ -862,8 +858,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { - account.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, s.accountRepo, account, refreshedCreds) } item.Action = "created" result.Created++ @@ -893,8 +888,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { - existing.Credentials = refreshedCreds - _ = s.accountRepo.Update(ctx, existing) + _ = persistAccountCredentials(ctx, s.accountRepo, existing, refreshedCreds) } item.Action = "updated" diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 384d5159..4ae5a469 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -119,6 +119,7 @@ const ( SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) + SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 diff --git a/backend/internal/service/email_service.go 
b/backend/internal/service/email_service.go index 44edf7f7..00691233 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -12,6 +12,7 @@ import ( "net/smtp" "net/url" "strconv" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -111,7 +112,7 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { return nil, fmt.Errorf("get smtp settings: %w", err) } - host := settings[SettingKeySMTPHost] + host := strings.TrimSpace(settings[SettingKeySMTPHost]) if host == "" { return nil, ErrEmailNotConfigured } @@ -128,10 +129,10 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) { return &SMTPConfig{ Host: host, Port: port, - Username: settings[SettingKeySMTPUsername], - Password: settings[SettingKeySMTPPassword], - From: settings[SettingKeySMTPFrom], - FromName: settings[SettingKeySMTPFromName], + Username: strings.TrimSpace(settings[SettingKeySMTPUsername]), + Password: strings.TrimSpace(settings[SettingKeySMTPPassword]), + From: strings.TrimSpace(settings[SettingKeySMTPFrom]), + FromName: strings.TrimSpace(settings[SettingKeySMTPFromName]), UseTLS: useTLS, }, nil } diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go new file mode 100644 index 00000000..d3c611e2 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -0,0 +1,485 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts an OpenAI 
Chat Completions API request body, +// converts it to Anthropic Messages format (chained via Responses format), +// forwards to the Anthropic upstream, and converts the response back to Chat +// Completions format. This enables Chat Completions clients to access Anthropic +// models through Anthropic platform groups. +func (s *GatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var ccReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &ccReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := ccReq.Model + clientStream := ccReq.Stream + includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage + + // 2. Convert CC → Responses → Anthropic (chained conversion) + responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq) + if err != nil { + return nil, fmt.Errorf("convert responses to anthropic: %w", err) + } + + // 3. Force upstream streaming + anthropicReq.Stream = true + reqStream := true + + // 4. 
Model mapping + mappedModel := originalModel + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(originalModel) + } + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(originalModel) + if normalized != originalModel { + mappedModel = normalized + } + } + anthropicReq.Model = mappedModel + + logger.L().Debug("gateway forward_as_chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("client_stream", clientStream), + ) + + // 5. Marshal Anthropic request body + anthropicBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal anthropic request: %w", err) + } + + // 6. Apply Claude Code mimicry for OAuth accounts + isClaudeCode := false // CC API is never Claude Code + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + if !strings.Contains(strings.ToLower(mappedModel), "haiku") && + !systemIncludesClaudeCodePrompt(anthropicReq.System) { + anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) + } + } + + // 7. Enforce cache_control block limit + anthropicBody = enforceCacheControlLimit(anthropicBody) + + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. 
Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 12. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + writeGatewayCCError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 13. Extract reasoning effort from CC request body + reasoningEffort := extractCCReasoningEffortFromBody(body) + + // 14. Handle normal response + // Read Anthropic SSE → convert to Responses events → convert to CC format + var result *ForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage) + } else { + result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } + + return result, handleErr +} + +// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions +// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort) +// formats used by OpenAI-compatible clients. 
+func extractCCReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +// handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full +// response, then converts Anthropic → Responses → Chat Completions. +func (s *GatewayService) handleCCBufferedFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResp *apicompat.AnthropicResponse + var usage ClaudeUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + // message_start carries the initial response structure and cache usage + if event.Type == "message_start" && event.Message != nil { + finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // message_delta carries final usage and stop_reason + if event.Type == "message_delta" { + if event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { + finalResp.StopReason = 
event.Delta.StopReason + } + } + if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil { + finalResp.Content = append(finalResp.Content, *event.ContentBlock) + } + if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil { + idx := *event.Index + if idx < len(finalResp.Content) { + switch event.Delta.Type { + case "text_delta": + finalResp.Content[idx].Text += event.Delta.Text + case "thinking_delta": + finalResp.Content[idx].Thinking += event.Delta.Thinking + case "input_json_delta": + finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON) + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResp == nil { + writeGatewayCCError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response") + return nil, fmt.Errorf("upstream stream ended without response") + } + + // Update usage from accumulated delta + if usage.InputTokens > 0 || usage.OutputTokens > 0 { + finalResp.Usage = apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + // Chain: Anthropic → Responses → Chat Completions + responsesResp := apicompat.AnthropicToResponsesResponse(finalResp) + ccResp := apicompat.ResponsesToChatCompletions(responsesResp, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, ccResp) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: 
reasoningEffort, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleCCStreamingFromAnthropic reads Anthropic SSE events, converts each +// to Responses events, then to Chat Completions chunks, and writes them. +func (s *GatewayService) handleCCStreamingFromAnthropic( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, + includeUsage bool, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + // Use Anthropic→Responses state machine, then convert Responses→CC + anthState := apicompat.NewAnthropicEventToResponsesState() + anthState.Model = originalModel + ccState := apicompat.NewResponsesEventToChatState() + ccState.Model = originalModel + ccState.IncludeUsage = includeUsage + + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *ForwardResult { + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + writeChunk := func(chunk apicompat.ChatCompletionsChunk) bool { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + return false + } + if _, err := fmt.Fprint(c.Writer, sse); 
err != nil { + return true // client disconnected + } + return false + } + + processAnthropicEvent := func(event *apicompat.AnthropicStreamEvent) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // Extract usage from message_delta + if event.Type == "message_delta" && event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + // Also capture usage from message_start (carries cache fields) + if event.Type == "message_start" && event.Message != nil { + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // Chain: Anthropic event → Responses events → CC chunks + responsesEvents := apicompat.AnthropicEventToResponsesEvents(event, anthState) + for _, resEvt := range responsesEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + if disconnected := writeChunk(chunk); disconnected { + return true + } + } + } + c.Writer.Flush() + return false + } + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + if processAnthropicEvent(&event) { + return resultWithUsage(), nil + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_cc stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Finalize both state machines + finalResEvents := apicompat.FinalizeAnthropicResponsesStream(anthState) + for _, resEvt := range finalResEvents { + ccChunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range ccChunks { + writeChunk(chunk) 
//nolint:errcheck + } + } + finalCCChunks := apicompat.FinalizeResponsesChatStream(ccState) + for _, chunk := range finalCCChunks { + writeChunk(chunk) //nolint:errcheck + } + + // Write [DONE] marker + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + + return resultWithUsage(), nil +} + +// writeGatewayCCError writes an error in OpenAI Chat Completions format for +// the Anthropic-upstream CC forwarding path. +func writeGatewayCCError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/gateway_forward_as_chat_completions_test.go b/backend/internal/service/gateway_forward_as_chat_completions_test.go new file mode 100644 index 00000000..5003e5b3 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractCCReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + t.Run("nested reasoning.effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + }) + + t.Run("flat reasoning_effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`)) + require.NotNil(t, got) + require.Equal(t, "xhigh", *got) + }) + + t.Run("missing effort", func(t *testing.T) { + require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`))) + }) +} + +func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := 
"high" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) +} + +func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := "medium" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: 
{"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "medium", *result.ReasoningEffort) + require.Contains(t, rec.Body.String(), `[DONE]`) +} diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go new file mode 100644 index 00000000..5dca57f9 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -0,0 +1,518 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ForwardAsResponses accepts an OpenAI Responses API request body, converts it +// to Anthropic Messages format, forwards to the Anthropic upstream, and converts +// the response back to Responses format. This enables OpenAI Responses API +// clients to access Anthropic models through Anthropic platform groups. 
+// +// The method follows the same pattern as OpenAIGatewayService.ForwardAsAnthropic +// but in reverse direction: Responses → Anthropic upstream → Responses. +func (s *GatewayService) ForwardAsResponses( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + parsed *ParsedRequest, +) (*ForwardResult, error) { + startTime := time.Now() + + // 1. Parse Responses request + var responsesReq apicompat.ResponsesRequest + if err := json.Unmarshal(body, &responsesReq); err != nil { + return nil, fmt.Errorf("parse responses request: %w", err) + } + originalModel := responsesReq.Model + clientStream := responsesReq.Stream + + // 2. Convert Responses → Anthropic + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(&responsesReq) + if err != nil { + return nil, fmt.Errorf("convert responses to anthropic: %w", err) + } + + // 3. Force upstream streaming (Anthropic works best with streaming) + anthropicReq.Stream = true + reqStream := true + + // 4. Model mapping + mappedModel := originalModel + reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(originalModel) + } + if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(originalModel) + if normalized != originalModel { + mappedModel = normalized + } + } + anthropicReq.Model = mappedModel + + logger.L().Debug("gateway forward_as_responses: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("client_stream", clientStream), + ) + + // 5. Marshal Anthropic request body + anthropicBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal anthropic request: %w", err) + } + + // 6. 
Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints) + isClaudeCode := false // Responses API is never Claude Code + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + if !strings.Contains(strings.ToLower(mappedModel), "haiku") && + !systemIncludesClaudeCodePrompt(anthropicReq.System) { + anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System) + } + } + + // 7. Enforce cache_control block limit + anthropicBody = enforceCacheControlLimit(anthropicBody) + + // 8. Get access token + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 9. Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 10. Build upstream request + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, anthropicBody, token, tokenType, mappedModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + // 11. Send request + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 12. 
Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + // Non-failover error: return Responses-formatted error to client + writeResponsesError(c, mapUpstreamStatusCode(resp.StatusCode), "server_error", upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + } + + // 13. Handle normal response (convert Anthropic → Responses) + var result *ForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } else { + result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) + } + + return result, handleErr +} + +// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort +// and normalizes it for usage logging. 
+func ExtractResponsesReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) { + if dst == nil { + return + } + if src.InputTokens > 0 { + dst.InputTokens = src.InputTokens + } + if src.OutputTokens > 0 { + dst.OutputTokens = src.OutputTokens + } + if src.CacheReadInputTokens > 0 { + dst.CacheReadInputTokens = src.CacheReadInputTokens + } + if src.CacheCreationInputTokens > 0 { + dst.CacheCreationInputTokens = src.CacheCreationInputTokens + } +} + +// handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from +// the upstream streaming response, assembles them into a complete Anthropic +// response, converts to Responses API JSON format, and writes it to the client. +func (s *GatewayService) handleResponsesBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + // Accumulate the final Anthropic response from streaming events + var finalResp *apicompat.AnthropicResponse + var usage ClaudeUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + eventType := strings.TrimPrefix(line, "event: ") + + // Read the data line + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event 
apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("forward_as_responses buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + zap.String("event_type", eventType), + ) + continue + } + + // message_start carries the initial response structure + if event.Type == "message_start" && event.Message != nil { + finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // message_delta carries final usage and stop_reason + if event.Type == "message_delta" { + if event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { + finalResp.StopReason = event.Delta.StopReason + } + } + + // Accumulate content blocks + if event.Type == "content_block_start" && event.ContentBlock != nil && finalResp != nil { + finalResp.Content = append(finalResp.Content, *event.ContentBlock) + } + if event.Type == "content_block_delta" && event.Delta != nil && finalResp != nil && event.Index != nil { + idx := *event.Index + if idx < len(finalResp.Content) { + switch event.Delta.Type { + case "text_delta": + finalResp.Content[idx].Text += event.Delta.Text + case "thinking_delta": + finalResp.Content[idx].Thinking += event.Delta.Thinking + case "input_json_delta": + finalResp.Content[idx].Input = appendRawJSON(finalResp.Content[idx].Input, event.Delta.PartialJSON) + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_responses buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResp == nil { + writeResponsesError(c, http.StatusBadGateway, "server_error", "Upstream stream ended without a response") + return nil, fmt.Errorf("upstream stream ended without response") + } + + // Update usage from accumulated delta + if 
usage.InputTokens > 0 || usage.OutputTokens > 0 { + finalResp.Usage = apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + // Convert to Responses format + responsesResp := apicompat.AnthropicToResponsesResponse(finalResp) + responsesResp.Model = originalModel // Use original model name + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, responsesResp) + + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleResponsesStreamingResponse reads Anthropic SSE events from upstream, +// converts each to Responses SSE events, and writes them to the client. 
+func (s *GatewayService) handleResponsesStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + reasoningEffort *string, + startTime time.Time, +) (*ForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewAnthropicEventToResponsesState() + state.Model = originalModel + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *ForwardResult { + return &ForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + // processEvent handles a single parsed Anthropic SSE event. 
+ processEvent := func(event *apicompat.AnthropicStreamEvent) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // Extract usage from message_delta + if event.Type == "message_delta" && event.Usage != nil { + mergeAnthropicUsage(&usage, *event.Usage) + } + // Also capture usage from message_start + if event.Type == "message_start" && event.Message != nil { + mergeAnthropicUsage(&usage, event.Message.Usage) + } + + // Convert to Responses events + events := apicompat.AnthropicEventToResponsesEvents(event, state) + for _, evt := range events { + sse, err := apicompat.ResponsesEventToSSE(evt) + if err != nil { + logger.L().Warn("forward_as_responses stream: failed to marshal event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("forward_as_responses stream: client disconnected", + zap.String("request_id", requestID), + ) + return true // client disconnected + } + } + if len(events) > 0 { + c.Writer.Flush() + } + return false + } + + finalizeStream := func() (*ForwardResult, error) { + if finalEvents := apicompat.FinalizeAnthropicResponsesStream(state); len(finalEvents) > 0 { + for _, evt := range finalEvents { + sse, err := apicompat.ResponsesEventToSSE(evt) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + c.Writer.Flush() + } + return resultWithUsage(), nil + } + + // Read Anthropic SSE events + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "event: ") { + continue + } + eventType := strings.TrimPrefix(line, "event: ") + + // Read data line + if !scanner.Scan() { + break + } + dataLine := scanner.Text() + if !strings.HasPrefix(dataLine, "data: ") { + continue + } + payload := dataLine[6:] + + var event apicompat.AnthropicStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + 
logger.L().Warn("forward_as_responses stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + zap.String("event_type", eventType), + ) + continue + } + + if processEvent(&event) { + return resultWithUsage(), nil + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("forward_as_responses stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + return finalizeStream() +} + +// appendRawJSON appends a JSON fragment string to existing raw JSON. +func appendRawJSON(existing json.RawMessage, fragment string) json.RawMessage { + if len(existing) == 0 { + return json.RawMessage(fragment) + } + return json.RawMessage(string(existing) + fragment) +} + +// writeResponsesError writes an error response in OpenAI Responses API format. +func writeResponsesError(c *gin.Context, statusCode int, code, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "code": code, + "message": message, + }, + }) +} + +// mapUpstreamStatusCode maps upstream HTTP status codes to appropriate client-facing codes. 
+func mapUpstreamStatusCode(code int) int { + if code >= 500 { + return http.StatusBadGateway + } + return code +} diff --git a/backend/internal/service/gateway_forward_as_responses_test.go b/backend/internal/service/gateway_forward_as_responses_test.go new file mode 100644 index 00000000..e48d8b22 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses_test.go @@ -0,0 +1,94 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractResponsesReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + + require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`))) +} + +func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, 
"claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `"cached_tokens":9`) +} + +func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `response.completed`) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go 
index 718cd42a..f28912bb 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 29b6cfd6..e2badfed 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,6 +5,8 @@ import ( "encoding/json" "fmt" "math" + "regexp" + "sort" "strings" "unsafe" @@ -34,6 +36,9 @@ var ( patternEmptyTextSpaced = []byte(`"text": ""`) patternEmptyTextSp1 = []byte(`"text" : ""`) patternEmptyTextSp2 = []byte(`"text" :""`) + + sessionUserAgentProductPattern = regexp.MustCompile(`([A-Za-z0-9._-]+)/[A-Za-z0-9._-]+`) + sessionUserAgentVersionPattern = regexp.MustCompile(`\bv?\d+(?:\.\d+){1,3}\b`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -75,6 +80,49 @@ type ParsedRequest struct { OnUpstreamAccepted func() } +// NormalizeSessionUserAgent reduces UA noise for sticky-session and digest hashing. 
+// It preserves the set of product names from Product/Version tokens while +// discarding version-only changes and incidental comments. +func NormalizeSessionUserAgent(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + + matches := sessionUserAgentProductPattern.FindAllStringSubmatch(raw, -1) + if len(matches) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + + products := make([]string, 0, len(matches)) + seen := make(map[string]struct{}, len(matches)) + for _, match := range matches { + if len(match) < 2 { + continue + } + product := strings.ToLower(strings.TrimSpace(match[1])) + if product == "" { + continue + } + if _, exists := seen[product]; exists { + continue + } + seen[product] = struct{}{} + products = append(products, product) + } + if len(products) == 0 { + return normalizeSessionUserAgentFallback(raw) + } + sort.Strings(products) + return strings.Join(products, "+") +} + +func normalizeSessionUserAgentFallback(raw string) string { + normalized := strings.ToLower(strings.Join(strings.Fields(raw), " ")) + normalized = sessionUserAgentVersionPattern.ReplaceAllString(normalized, "") + return strings.Join(strings.Fields(normalized), " ") +} + // ParseGatewayRequest 解析网关请求体并返回结构化结果。 // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 @@ -205,6 +253,118 @@ func sliceRawFromBody(body []byte, r gjson.Result) []byte { return []byte(r.Raw) } +// stripEmptyTextBlocksFromSlice removes empty text blocks from a content slice (including nested tool_result content). +// Returns (cleaned slice, true) if any blocks were removed, or (original, false) if unchanged. 
+func stripEmptyTextBlocksFromSlice(blocks []any) ([]any, bool) { + var result []any + changed := false + for i, block := range blocks { + blockMap, ok := block.(map[string]any) + if !ok { + if result != nil { + result = append(result, block) + } + continue + } + blockType, _ := blockMap["type"].(string) + + // Strip empty text blocks + if blockType == "text" { + if txt, _ := blockMap["text"].(string); txt == "" { + if result == nil { + result = make([]any, 0, len(blocks)) + result = append(result, blocks[:i]...) + } + changed = true + continue + } + } + + // Recurse into tool_result nested content + if blockType == "tool_result" { + if nestedContent, ok := blockMap["content"].([]any); ok { + if cleaned, nestedChanged := stripEmptyTextBlocksFromSlice(nestedContent); nestedChanged { + if result == nil { + result = make([]any, 0, len(blocks)) + result = append(result, blocks[:i]...) + } + changed = true + blockCopy := make(map[string]any, len(blockMap)) + for k, v := range blockMap { + blockCopy[k] = v + } + blockCopy["content"] = cleaned + result = append(result, blockCopy) + continue + } + } + } + + if result != nil { + result = append(result, block) + } + } + if !changed { + return blocks, false + } + return result, true +} + +// StripEmptyTextBlocks removes empty text blocks from the request body (including nested tool_result content). +// This is a lightweight pre-filter for the initial request path to prevent upstream 400 errors. +// Returns the original body unchanged if no empty text blocks are found. 
+func StripEmptyTextBlocks(body []byte) []byte { + // Fast path: check if body contains empty text patterns + hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) || + bytes.Contains(body, patternEmptyTextSpaced) || + bytes.Contains(body, patternEmptyTextSp1) || + bytes.Contains(body, patternEmptyTextSp2) + if !hasEmptyTextBlock { + return body + } + + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { + return body + } + + modified := false + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := msgMap["content"].([]any) + if !ok { + continue + } + if cleaned, changed := stripEmptyTextBlocksFromSlice(content); changed { + modified = true + msgMap["content"] = cleaned + } + } + + if !modified { + return body + } + + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err := sjson.SetRawBytes(body, "messages", msgsBytes) + if err != nil { + return body + } + return out +} + // FilterThinkingBlocks removes thinking blocks from request body // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures @@ -378,6 +538,23 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } + // Recursively strip empty text blocks from tool_result nested content. 
+ if blockType == "tool_result" { + if nestedContent, ok := blockMap["content"].([]any); ok { + if cleaned, changed := stripEmptyTextBlocksFromSlice(nestedContent); changed { + modifiedThisMsg = true + ensureNewContent(bi) + blockCopy := make(map[string]any, len(blockMap)) + for k, v := range blockMap { + blockCopy[k] = v + } + blockCopy["content"] = cleaned + newContent = append(newContent, blockCopy) + continue + } + } + } + if newContent != nil { newContent = append(newContent, block) } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index b11fee9b..d262456d 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -435,6 +435,122 @@ func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) { require.NotEmpty(t, block1["text"]) } +func TestFilterThinkingBlocksForRetry_StripsNestedEmptyTextInToolResult(t *testing.T) { + // Empty text blocks nested inside tool_result content should also be stripped + input := []byte(`{ + "messages":[ + {"role":"user","content":[ + {"type":"tool_result","tool_use_id":"t1","content":[ + {"type":"text","text":"valid result"}, + {"type":"text","text":""} + ]} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 1) + toolResult := content0[0].(map[string]any) + require.Equal(t, "tool_result", toolResult["type"]) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 1) + require.Equal(t, "valid result", nestedContent[0].(map[string]any)["text"]) +} + +func TestFilterThinkingBlocksForRetry_NestedAllEmptyGetsEmptySlice(t *testing.T) { + // If all nested content blocks in tool_result are empty text, content becomes empty slice + input := []byte(`{ + 
"messages":[ + {"role":"user","content":[ + {"type":"tool_result","tool_use_id":"t1","content":[ + {"type":"text","text":""} + ]}, + {"type":"text","text":"hello"} + ]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 2) + toolResult := content0[0].(map[string]any) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 0) +} + +func TestStripEmptyTextBlocks(t *testing.T) { + t.Run("strips top-level empty text", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + require.Len(t, content, 1) + require.Equal(t, "hello", content[0].(map[string]any)["text"]) + }) + + t.Run("strips nested empty text in tool_result", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"text","text":"ok"},{"type":"text","text":""}]}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + toolResult := content[0].(map[string]any) + nestedContent := toolResult["content"].([]any) + require.Len(t, nestedContent, 1) + require.Equal(t, "ok", nestedContent[0].(map[string]any)["text"]) + }) + + t.Run("no-op when no empty text", func(t *testing.T) { + input := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + out := StripEmptyTextBlocks(input) + require.Equal(t, input, out) + }) + + t.Run("preserves 
non-map blocks in content", func(t *testing.T) { + // tool_result content can be a string; non-map blocks should pass through unchanged + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":"string content"},{"type":"text","text":""}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + require.Len(t, content, 1) + toolResult := content[0].(map[string]any) + require.Equal(t, "tool_result", toolResult["type"]) + require.Equal(t, "string content", toolResult["content"]) + }) + + t.Run("handles deeply nested tool_result", func(t *testing.T) { + // Recursive: tool_result containing another tool_result with empty text + input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_result","tool_use_id":"t2","content":[{"type":"text","text":""},{"type":"text","text":"deep"}]}]}]}]}`) + out := StripEmptyTextBlocks(input) + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs := req["messages"].([]any) + content := msgs[0].(map[string]any)["content"].([]any) + outer := content[0].(map[string]any) + innerContent := outer["content"].([]any) + inner := innerContent[0].(map[string]any) + deepContent := inner["content"].([]any) + require.Len(t, deepContent, 1) + require.Equal(t, "deep", deepContent[0].(map[string]any)["text"]) + }) +} + func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) { // Non-empty text blocks should pass through unchanged input := []byte(`{ diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 72cef2ac..402975d7 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -658,7 +658,7 @@ func (s *GatewayService) 
GenerateSessionHash(parsed *ParsedRequest) string { if parsed.SessionContext != nil { _, _ = combined.WriteString(parsed.SessionContext.ClientIP) _, _ = combined.WriteString(":") - _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent)) _, _ = combined.WriteString(":") _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) _, _ = combined.WriteString("|") @@ -4119,6 +4119,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 调试日志:记录即将转发的账号信息 logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. + body = StripEmptyTextBlocks(body) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 setOpsUpstreamRequestBody(c, body) @@ -4148,6 +4151,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -4174,6 +4178,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "signature_error", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4228,6 +4233,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: retryResp.StatusCode, UpstreamRequestID: retryResp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(retryReq.URL.String()), Kind: 
"signature_retry_thinking", Message: extractUpstreamErrorMessage(retryRespBody), Detail: func() string { @@ -4258,6 +4264,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(retryReq2.URL.String()), Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) @@ -4297,6 +4304,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "budget_constraint_error", Message: errMsg, Detail: func() string { @@ -4358,6 +4366,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -4603,6 +4612,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( if c != nil { c.Set("anthropic_passthrough", true) } + // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. 
+ input.Body = StripEmptyTextBlocks(input.Body) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 setOpsUpstreamRequestBody(c, input.Body) @@ -4628,6 +4640,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: safeErr, @@ -4667,6 +4680,7 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "retry", Message: extractUpstreamErrorMessage(respBody), @@ -5344,6 +5358,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "request_error", Message: safeErr, }) @@ -5380,6 +5395,7 @@ func (s *GatewayService) executeBedrockUpstream( AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { @@ -7877,6 +7893,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body := parsed.Body reqModel := parsed.Model + // Pre-filter: strip empty text blocks to prevent upstream 400. 
+ body = StripEmptyTextBlocks(body) + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode @@ -8064,6 +8083,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "request_error", Message: sanitizeUpstreamErrorMessage(err.Error()), @@ -8119,6 +8139,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, Kind: "http_error", Message: upstreamMsg, diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index a78c56e7..5e09b95a 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git 
a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 1780d1da..cd291328 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -52,10 +52,11 @@ func BuildGeminiDigestChain(req *antigravity.GeminiRequest) string { // 返回 16 字符的 Base64 编码的 SHA256 前缀 func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, model string) string { // 组合所有标识符 + normalizedUserAgent := NormalizeSessionUserAgent(userAgent) combined := strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(apiKeyID, 10) + ":" + ip + ":" + - userAgent + ":" + + normalizedUserAgent + ":" + platform + ":" + model diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index a034cddd..27321996 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,6 +152,24 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } +func TestGenerateGeminiPrefixHash_IgnoresUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Mozilla/5.0 codex_cli_rs/0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + +func TestGenerateGeminiPrefixHash_IgnoresFreeformUserAgentVersionNoise(t *testing.T) { + hash1 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.0", "antigravity", "gemini-2.5-pro") + hash2 := GenerateGeminiPrefixHash(1, 100, "192.168.1.1", "Codex CLI 0.1.1", "antigravity", "gemini-2.5-pro") + + if hash1 != hash2 { + t.Fatalf("free-form version-only User-Agent changes should not perturb Gemini prefix hash: %s vs %s", hash1, hash2) + } +} + func TestParseGeminiSessionValue(t *testing.T) 
{ tests := []struct { name string diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 1dab67c4..7add3460 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -135,7 +135,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou if tierID != "" { account.Credentials["tier_id"] = tierID } - _ = p.accountRepo.Update(ctx, account) + _ = persistAccountCredentials(ctx, p.accountRepo, account, account.Credentials) } } diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index f91fb4c9..39679c3d 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -504,6 +504,48 @@ func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") } +func TestGenerateSessionHash_SessionContext_UAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.0")) + h2 := svc.GenerateSessionHash(base("Mozilla/5.0 codex_cli_rs/0.1.1")) + require.Equal(t, h1, h2, "version-only User-Agent changes should not perturb the sticky session hash") +} + +func TestGenerateSessionHash_SessionContext_FreeformUAVersionNoiseIgnored(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + 
APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Codex CLI 0.1.0")) + h2 := svc.GenerateSessionHash(base("Codex CLI 0.1.1")) + require.Equal(t, h1, h2, "free-form version-only User-Agent changes should not perturb the sticky session hash") +} + func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { svc := &GatewayService{} diff --git a/backend/internal/service/oauth_refresh_api.go b/backend/internal/service/oauth_refresh_api.go index 17b9128c..5dbba638 100644 --- a/backend/internal/service/oauth_refresh_api.go +++ b/backend/internal/service/oauth_refresh_api.go @@ -108,8 +108,7 @@ func (api *OAuthRefreshAPI) RefreshIfNeeded( // 5. 设置版本号 + 更新 DB if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - freshAccount.Credentials = newCredentials - if updateErr := api.accountRepo.Update(ctx, freshAccount); updateErr != nil { + if updateErr := persistAccountCredentials(ctx, api.accountRepo, freshAccount, newCredentials); updateErr != nil { slog.Error("oauth_refresh_update_failed", "account_id", freshAccount.ID, "error", updateErr, diff --git a/backend/internal/service/oauth_refresh_api_test.go b/backend/internal/service/oauth_refresh_api_test.go index 6cf9371f..c3b38ddf 100644 --- a/backend/internal/service/oauth_refresh_api_test.go +++ b/backend/internal/service/oauth_refresh_api_test.go @@ -16,10 +16,11 @@ import ( // refreshAPIAccountRepo implements AccountRepository for OAuthRefreshAPI tests. 
type refreshAPIAccountRepo struct { mockAccountRepoForGemini - account *Account // returned by GetByID - getByIDErr error - updateErr error - updateCalls int + account *Account // returned by GetByID + getByIDErr error + updateErr error + updateCalls int + updateCredentialsCalls int } func (r *refreshAPIAccountRepo) GetByID(_ context.Context, _ int64) (*Account, error) { @@ -34,6 +35,19 @@ func (r *refreshAPIAccountRepo) Update(_ context.Context, _ *Account) error { return r.updateErr } +func (r *refreshAPIAccountRepo) UpdateCredentials(_ context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + if r.account == nil || r.account.ID != id { + r.account = &Account{ID: id} + } + r.account.Credentials = cloneCredentials(credentials) + return nil +} + // refreshAPIExecutorStub implements OAuthRefreshExecutor for tests. type refreshAPIExecutorStub struct { needsRefresh bool @@ -106,10 +120,36 @@ func TestRefreshIfNeeded_Success(t *testing.T) { require.Equal(t, "new-token", result.NewCredentials["access_token"]) require.NotNil(t, result.NewCredentials["_token_version"]) // version stamp set require.Equal(t, 1, repo.updateCalls) // DB updated - require.Equal(t, 1, cache.releaseCalls) // lock released + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 1, cache.releaseCalls) // lock released require.Equal(t, 1, executor.refreshCalls) } +func TestRefreshIfNeeded_UpdateCredentialsPreservesRateLimitState(t *testing.T) { + resetAt := time.Now().Add(45 * time.Minute) + account := &Account{ + ID: 11, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + } + repo := &refreshAPIAccountRepo{account: account} + cache := &refreshAPICacheStub{lockResult: true} + executor := &refreshAPIExecutorStub{ + needsRefresh: true, + credentials: map[string]any{"access_token": "safe-token"}, + } + + api := NewOAuthRefreshAPI(repo, cache) + 
result, err := api.RefreshIfNeeded(context.Background(), account, executor, 3*time.Minute) + + require.NoError(t, err) + require.True(t, result.Refreshed) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotNil(t, repo.account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *repo.account.RateLimitResetAt, time.Second) +} + func TestRefreshIfNeeded_LockHeld(t *testing.T) { account := &Account{ID: 2, Platform: PlatformAnthropic} repo := &refreshAPIAccountRepo{account: account} @@ -193,7 +233,7 @@ func TestRefreshIfNeeded_RefreshError(t *testing.T) { require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "invalid_grant") - require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error + require.Equal(t, 0, repo.updateCalls) // no DB update on refresh error require.Equal(t, 1, cache.releaseCalls) // lock still released via defer } @@ -299,8 +339,8 @@ func TestMergeCredentials_NewOverridesOld(t *testing.T) { result := MergeCredentials(old, new) - require.Equal(t, "new-token", result["access_token"]) // overridden - require.Equal(t, "old-refresh", result["refresh_token"]) // preserved + require.Equal(t, "new-token", result["access_token"]) // overridden + require.Equal(t, "old-refresh", result["refresh_token"]) // preserved } // ========== BuildClaudeAccountCredentials tests ========== diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 789888cb..37e7ed2c 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -330,6 +330,11 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } + account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel) + if account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, 
nil + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { @@ -691,6 +696,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { continue } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) { + continue + } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if acquireErr != nil { return nil, len(candidates), topK, loadSkew, acquireErr diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 977c4ee8..088815ed 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -84,6 +84,61 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa require.Equal(t, int64(32002), account.ID) } +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeRecheckSkipsStaleCachedAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(10103) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleSticky := &Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 
5} + cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{staleSticky, staleBackup}, + accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + cache: cache, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(33002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} + +func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeRecheckSkipsStaleCachedCandidate(t *testing.T) { + ctx := context.Background() + groupID := int64(10104) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + stalePrimary := &Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0} + staleSecondary := &Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + dbPrimary := Account{ID: 34001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} + dbSecondary := Account{ID: 34002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: 
[]*Account{stalePrimary, staleSecondary}, + accountsByID: map[int64]*Account{34001: stalePrimary, 34002: staleSecondary}, + } + snapshotService := &SchedulerSnapshotService{cache: snapshotCache} + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + cfg: &config.Config{}, + schedulerSnapshot: snapshotService, + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(34002), account.ID) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4e96cf05..a72a86ac 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1201,6 +1201,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1229,6 +1234,10 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel) + if fresh == nil { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1353,27 +1362,32 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { - result, err := 
s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -1560,6 +1574,28 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. 
return fresh } +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account { + if account == nil { + return nil + } + if s.schedulerSnapshot == nil || s.accountRepo == nil { + return account + } + + latest, err := s.accountRepo.GetByID(ctx, account.ID) + if err != nil || latest == nil { + return nil + } + syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, latest, time.Now()) + if !latest.IsSchedulable() || !latest.IsOpenAI() { + return nil + } + if requestedModel != "" && !latest.IsModelSupported(requestedModel) { + return nil + } + return latest +} + func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { var ( account *Account @@ -2598,6 +2634,12 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + if s.rateLimitService != nil { + // Passthrough mode preserves the raw upstream error response, but runtime + // account state still needs to be updated so sticky routing can stop + // reusing a freshly rate-limited account. 
+ _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index f51a7491..fe639576 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -536,6 +536,55 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF require.True(t, arr[len(arr)-1].Passthrough) } +func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + resetAt := time.Now().Add(7 * 24 * time.Hour).Unix() + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-rate-limit"}, + }, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))), + } + upstream := &httpUpstreamRecorder{resp: resp} + repo := &openAIWSRateLimitSignalRepo{} + rateSvc := &RateLimitService{accountRepo: repo} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + rateLimitService: rateSvc, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: 
map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Contains(t, rec.Body.String(), "usage_limit_reached") + require.Len(t, repo.rateLimitCalls, 1) + require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second) +} + func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index bd82e107..0a1266d9 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -29,9 +29,10 @@ type soraSessionChunk struct { // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { - sessionStore *openai.SessionStore - proxyRepo ProxyRepository - oauthClient OpenAIOAuthClient + sessionStore *openai.SessionStore + proxyRepo ProxyRepository + oauthClient OpenAIOAuthClient + privacyClientFactory PrivacyClientFactory // 用于调用 chatgpt.com/backend-api(ImpersonateChrome) } // NewOpenAIOAuthService creates a new OpenAI OAuth service @@ -43,6 +44,12 @@ func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthCli } } +// SetPrivacyClientFactory 注入 ImpersonateChrome 客户端工厂, +// 用于调用 chatgpt.com/backend-api 获取账号信息(plan_type 等)。 +func (s *OpenAIOAuthService) SetPrivacyClientFactory(factory PrivacyClientFactory) { + s.privacyClientFactory = factory +} + // OpenAIAuthURLResult contains the authorization URL and session info type OpenAIAuthURLResult struct { AuthURL string `json:"auth_url"` @@ -131,6 +138,7 @@ type OpenAITokenInfo struct { ChatGPTUserID string 
`json:"chatgpt_user_id,omitempty"` OrganizationID string `json:"organization_id,omitempty"` PlanType string `json:"plan_type,omitempty"` + PrivacyMode string `json:"privacy_mode,omitempty"` } // ExchangeCode exchanges authorization code for tokens @@ -251,6 +259,30 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre tokenInfo.PlanType = userInfo.PlanType } + // id_token 中缺少 plan_type 时(如 Mobile RT),尝试通过 ChatGPT backend-api 补全 + if tokenInfo.PlanType == "" && tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { + // 从 access_token JWT 中提取 orgID(poid),用于匹配正确的账号 + orgID := tokenInfo.OrganizationID + if orgID == "" { + if atClaims, err := openai.DecodeIDToken(tokenInfo.AccessToken); err == nil && atClaims.OpenAIAuth != nil { + orgID = atClaims.OpenAIAuth.POID + } + } + if info := fetchChatGPTAccountInfo(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL, orgID); info != nil { + if tokenInfo.PlanType == "" && info.PlanType != "" { + tokenInfo.PlanType = info.PlanType + } + if tokenInfo.Email == "" && info.Email != "" { + tokenInfo.Email = info.Email + } + } + } + + // 尝试设置隐私(关闭训练数据共享),best-effort + if tokenInfo.AccessToken != "" && s.privacyClientFactory != nil { + tokenInfo.PrivacyMode = disableOpenAITraining(ctx, s.privacyClientFactory, tokenInfo.AccessToken, proxyURL) + } + return tokenInfo, nil } diff --git a/backend/internal/service/openai_privacy_service.go b/backend/internal/service/openai_privacy_service.go index 90cd522d..d5966006 100644 --- a/backend/internal/service/openai_privacy_service.go +++ b/backend/internal/service/openai_privacy_service.go @@ -69,6 +69,139 @@ func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFacto return PrivacyModeTrainingOff } +// ChatGPTAccountInfo 从 chatgpt.com/backend-api/accounts/check 获取的账号信息 +type ChatGPTAccountInfo struct { + PlanType string + Email string +} + +const chatGPTAccountsCheckURL = 
"https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27" + +// fetchChatGPTAccountInfo calls ChatGPT backend-api to get account info (plan_type, etc.). +// Used as fallback when id_token doesn't contain these fields (e.g., Mobile RT). +// orgID is used to match the correct account when multiple accounts exist (e.g., personal + team). +// Returns nil on any failure (best-effort, non-blocking). +func fetchChatGPTAccountInfo(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL, orgID string) *ChatGPTAccountInfo { + if accessToken == "" || clientFactory == nil { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + client, err := clientFactory(proxyURL) + if err != nil { + slog.Debug("chatgpt_account_check_client_error", "error", err.Error()) + return nil + } + + var result map[string]any + resp, err := client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Origin", "https://chatgpt.com"). + SetHeader("Referer", "https://chatgpt.com/"). + SetHeader("Accept", "application/json"). + SetSuccessResult(&result). 
+ Get(chatGPTAccountsCheckURL) + + if err != nil { + slog.Debug("chatgpt_account_check_request_error", "error", err.Error()) + return nil + } + + if !resp.IsSuccessState() { + slog.Debug("chatgpt_account_check_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200)) + return nil + } + + info := &ChatGPTAccountInfo{} + + accounts, ok := result["accounts"].(map[string]any) + if !ok { + slog.Debug("chatgpt_account_check_no_accounts", "body", truncate(resp.String(), 300)) + return nil + } + + // 优先匹配 orgID 对应的账号(access_token JWT 中的 poid) + if orgID != "" { + if matched := extractPlanFromAccount(accounts, orgID); matched != "" { + info.PlanType = matched + } + } + + // 未匹配到时,遍历所有账号:优先 is_default,次选非 free + if info.PlanType == "" { + var defaultPlan, paidPlan, anyPlan string + for _, acctRaw := range accounts { + acct, ok := acctRaw.(map[string]any) + if !ok { + continue + } + planType := extractPlanType(acct) + if planType == "" { + continue + } + if anyPlan == "" { + anyPlan = planType + } + if account, ok := acct["account"].(map[string]any); ok { + if isDefault, _ := account["is_default"].(bool); isDefault { + defaultPlan = planType + } + } + if !strings.EqualFold(planType, "free") && paidPlan == "" { + paidPlan = planType + } + } + // 优先级:default > 非 free > 任意 + switch { + case defaultPlan != "": + info.PlanType = defaultPlan + case paidPlan != "": + info.PlanType = paidPlan + default: + info.PlanType = anyPlan + } + } + + if info.PlanType == "" { + slog.Debug("chatgpt_account_check_no_plan_type", "body", truncate(resp.String(), 300)) + return nil + } + + slog.Info("chatgpt_account_check_success", "plan_type", info.PlanType, "org_id", orgID) + return info +} + +// extractPlanFromAccount 从 accounts map 中按 key(account_id)精确匹配并提取 plan_type +func extractPlanFromAccount(accounts map[string]any, accountKey string) string { + acctRaw, ok := accounts[accountKey] + if !ok { + return "" + } + acct, ok := acctRaw.(map[string]any) + if !ok { + return "" + } + 
return extractPlanType(acct) +} + +// extractPlanType 从单个 account 对象中提取 plan_type +func extractPlanType(acct map[string]any) string { + if account, ok := acct["account"].(map[string]any); ok { + if planType, ok := account["plan_type"].(string); ok && planType != "" { + return planType + } + } + if entitlement, ok := acct["entitlement"].(map[string]any); ok { + if subPlan, ok := entitlement["subscription_plan"].(string); ok && subPlan != "" { + return subPlan + } + } + return "" +} + func truncate(s string, n int) string { if len(s) <= n { return s diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go index 9a8803d3..a5b97ca9 100644 --- a/backend/internal/service/openai_ws_account_sticky_test.go +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -85,6 +85,58 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss( require.Zero(t, boundAccountID) } +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheckRateLimitedMiss(t *testing.T) { + ctx := context.Background() + groupID := int64(24) + rateLimitedUntil := time.Now().Add(30 * time.Minute) + staleAccount := &Account{ + ID: 13, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + dbAccount := Account{ + ID: 13, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + RateLimitResetAt: &rateLimitedUntil, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + snapshotCache := &openAISnapshotCacheStub{ + accountsByID: map[int64]*Account{dbAccount.ID: staleAccount}, + } + svc := &OpenAIGatewayService{ + accountRepo: 
stubOpenAIAccountRepo{accounts: []Account{dbAccount}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连") + boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl") + require.NoError(t, getErr) + require.Zero(t, boundAccountID) +} + func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { ctx := context.Background() groupID := int64(23) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 814ec0bd..4f1837c4 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -3846,6 +3846,11 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil, nil } + account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) + if account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index f5c79923..ffe79152 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -73,12 +73,13 @@ func (r *openAICodexExtraListRepo) SetRateLimited(_ 
context.Context, _ int64, re return nil } -func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { +func (r *openAICodexExtraListRepo) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) { _ = platform _ = accountType _ = status _ = search _ = groupID + _ = privacyMode return r.accounts, &pagination.PaginationResult{Total: int64(len(r.accounts)), Page: params.Page, PageSize: params.PageSize}, nil } @@ -491,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount( } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Len(t, accounts, 1) diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index a571dd4d..ad303d92 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "", 0) + }, platformFilter, "", "", "", 0, "") if err != nil { return nil, err } diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 2ed06d90..5fefb74f 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -62,6 +62,12 @@ 
type OpsErrorLog struct { ClientIP *string `json:"client_ip"` RequestPath string `json:"request_path"` Stream bool `json:"stream"` + + InboundEndpoint string `json:"inbound_endpoint"` + UpstreamEndpoint string `json:"upstream_endpoint"` + RequestedModel string `json:"requested_model"` + UpstreamModel string `json:"upstream_model"` + RequestType *int16 `json:"request_type"` } type OpsErrorLogDetail struct { diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 0ce9d425..04bf91c8 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -79,6 +79,17 @@ type OpsInsertErrorLogInput struct { Model string RequestPath string Stream bool + // InboundEndpoint is the normalized client-facing API endpoint path, e.g. /v1/chat/completions. + InboundEndpoint string + // UpstreamEndpoint is the normalized upstream endpoint path, e.g. /v1/responses. + UpstreamEndpoint string + // RequestedModel is the client-requested model name before mapping. + RequestedModel string + // UpstreamModel is the actual model sent to upstream after mapping. Empty means no mapping. + UpstreamModel string + // RequestType is the granular request type: 0=unknown, 1=sync, 2=stream, 3=ws_v2. + // Matches service.RequestType enum semantics from usage_log.go. + RequestType *int16 UserAgent string ErrorPhase string diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 9adf5896..05d444e1 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -93,6 +93,10 @@ type OpsUpstreamErrorEvent struct { UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"` + // UpstreamURL is the actual upstream URL that was called (host + path, query/fragment stripped). + // Helps debug 404/routing errors by showing which endpoint was targeted. 
+ UpstreamURL string `json:"upstream_url,omitempty"` + // Best-effort upstream request capture (sanitized+trimmed). // Required for retrying a specific upstream attempt. UpstreamRequestBody string `json:"upstream_request_body,omitempty"` @@ -119,6 +123,7 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody) ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody) ev.Kind = strings.TrimSpace(ev.Kind) + ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL) ev.Message = strings.TrimSpace(ev.Message) ev.Detail = strings.TrimSpace(ev.Detail) if ev.Message != "" { @@ -205,3 +210,19 @@ func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) { } return out, nil } + +// safeUpstreamURL returns scheme + host + path from a URL, stripping query/fragment +// to avoid leaking sensitive query parameters (e.g. OAuth tokens). +func safeUpstreamURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + if idx := strings.IndexByte(rawURL, '?'); idx >= 0 { + rawURL = rawURL[:idx] + } + if idx := strings.IndexByte(rawURL, '#'); idx >= 0 { + rawURL = rawURL[:idx] + } + return rawURL +} diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go index 50ceaa0e..fa6d1085 100644 --- a/backend/internal/service/ops_upstream_context_test.go +++ b/backend/internal/service/ops_upstream_context_test.go @@ -8,6 +8,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestSafeUpstreamURL(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"strips query", "https://api.anthropic.com/v1/messages?beta=true", "https://api.anthropic.com/v1/messages"}, + {"strips fragment", "https://api.openai.com/v1/responses#frag", "https://api.openai.com/v1/responses"}, + {"strips both", "https://host/path?token=secret#x", "https://host/path"}, + {"no query 
or fragment", "https://host/path", "https://host/path"}, + {"empty string", "", ""}, + {"whitespace only", " ", ""}, + {"query before fragment", "https://h/p?a=1#f", "https://h/p"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, safeUpstreamURL(tt.input)) + }) + } +} + func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 5c6c26e1..afe5816d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -163,7 +163,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc account.Credentials = make(map[string]any) } account.Credentials["expires_at"] = time.Now().Format(time.RFC3339) - if err := s.accountRepo.Update(ctx, account); err != nil { + if err := persistAccountCredentials(ctx, s.accountRepo, account, account.Credentials); err != nil { slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err) } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 4a6e5d6c..67b22e52 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -15,9 +15,11 @@ import ( type rateLimitAccountRepoStub struct { mockAccountRepoForGemini - setErrorCalls int - tempCalls int - lastErrorMsg string + setErrorCalls int + tempCalls int + updateCredentialsCalls int + lastCredentials map[string]any + lastErrorMsg string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -31,6 +33,12 @@ func (r *rateLimitAccountRepoStub) 
SetTempUnschedulable(ctx context.Context, id return nil } +func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCredentialsCalls++ + r.lastCredentials = cloneCredentials(credentials) + return nil +} + type tokenCacheInvalidatorRecorder struct { accounts []*Account err error @@ -110,6 +118,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin require.True(t, shouldDisable) require.Equal(t, 0, repo.setErrorCalls) require.Equal(t, 1, repo.tempCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Len(t, invalidator.accounts, 1) } @@ -130,3 +139,22 @@ func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) } + +func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 103, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.NotEmpty(t, repo.lastCredentials["expires_at"]) +} diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index 1e990e3a..7796a85e 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -81,7 +81,7 @@ func (m *sessionWindowMockRepo) Delete(context.Context, int64) error { panic( func (m *sessionWindowMockRepo) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { 
panic("unexpected") } -func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]Account, *pagination.PaginationResult, error) { +func (m *sessionWindowMockRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { panic("unexpected") } func (m *sessionWindowMockRepo) ListByGroup(context.Context, int64) ([]Account, error) { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f652839c..44d20491 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -150,6 +150,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, SettingKeyCustomMenuItems, + SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, SettingKeyBackendModeEnabled, } @@ -195,6 +196,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", }, nil @@ -247,6 +249,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` + CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string 
`json:"version,omitempty"` @@ -272,6 +275,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: s.version, @@ -314,6 +318,18 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". +func safeRawJSONArray(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" { + return json.RawMessage("[]") + } + if json.Valid([]byte(raw)) { + return json.RawMessage(raw) + } + return json.RawMessage("[]") +} + // GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. 
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { @@ -454,6 +470,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems + updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -740,6 +757,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionURL: "", SettingKeySoraClientEnabled: "false", SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -805,6 +823,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", CustomMenuItems: settings[SettingKeyCustomMenuItems], + CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index cd0bed0b..cf1d5eed 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -43,6 +43,7 @@ type SystemSettings struct { PurchaseSubscriptionURL string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints DefaultConcurrency int DefaultBalance float64 @@ -104,6 +105,7 @@ type PublicSettings 
struct { PurchaseSubscriptionURL string SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items + CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool BackendModeEnabled bool diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index f9221c5b..6243f867 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -947,7 +947,7 @@ func (c *SoraSDKClient) applyRecoveredToken(ctx context.Context, account *Accoun } if c.accountRepo != nil { - if err := c.accountRepo.Update(ctx, account); err != nil && c.debugEnabled() { + if err := persistAccountCredentials(ctx, c.accountRepo, account, account.Credentials); err != nil && c.debugEnabled() { c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) } } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 582afcd3..24b7424f 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -280,8 +280,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc newCredentials, err = refresher.Refresh(ctx, account) if newCredentials != nil { newCredentials["_token_version"] = time.Now().UnixMilli() - account.Credentials = newCredentials - if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + if saveErr := persistAccountCredentials(ctx, s.accountRepo, account, newCredentials); saveErr != nil { return fmt.Errorf("failed to save credentials: %w", saveErr) } } diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index f48de65e..60ba4a96 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,19 +14,40 @@ import ( 
type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - clearTempCalls int - lastAccount *Account - updateErr error + updateCalls int + fullUpdateCalls int + updateCredentialsCalls int + setErrorCalls int + clearTempCalls int + lastAccount *Account + updateErr error } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { r.updateCalls++ + r.fullUpdateCalls++ r.lastAccount = account return r.updateErr } +func (r *tokenRefreshAccountRepo) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error { + r.updateCalls++ + r.updateCredentialsCalls++ + if r.updateErr != nil { + return r.updateErr + } + cloned := cloneCredentials(credentials) + if r.accountsByID != nil { + if acc, ok := r.accountsByID[id]; ok && acc != nil { + acc.Credentials = cloned + r.lastAccount = acc + return nil + } + } + r.lastAccount = &Account{ID: id, Credentials: cloned} + return nil +} + func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ return nil @@ -112,6 +133,8 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) require.Equal(t, 1, invalidator.calls) require.Equal(t, "new-token", account.GetCredential("access_token")) } @@ -249,9 +272,43 @@ func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) { err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) + require.Equal(t, 1, repo.updateCredentialsCalls) require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效 } +func 
TestTokenRefreshService_RefreshWithRetry_UsesCredentialsUpdater(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + resetAt := time.Now().Add(30 * time.Minute) + account := &Account{ + ID: 17, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + RateLimitResetAt: &resetAt, + Credentials: map[string]any{ + "access_token": "old-token", + }, + } + refresher := &tokenRefresherStub{ + credentials: map[string]any{ + "access_token": "new-token", + }, + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.NoError(t, err) + require.Equal(t, 1, repo.updateCredentialsCalls) + require.Equal(t, 0, repo.fullUpdateCalls) + require.NotNil(t, account.RateLimitResetAt) + require.WithinDuration(t, resetAt, *account.RateLimitResetAt, time.Second) +} + // TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况 func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) { repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")} @@ -390,7 +447,7 @@ func TestTokenRefreshService_RefreshWithRetry_ClearsTempUnschedulable(t *testing err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) require.NoError(t, err) require.Equal(t, 1, repo.updateCalls) - require.Equal(t, 1, repo.clearTempCalls) // DB 清除 + require.Equal(t, 1, repo.clearTempCalls) // DB 清除 require.Equal(t, 1, tempCache.deleteCalls) // Redis 缓存也应清除 } diff --git a/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql new file mode 100644 index 00000000..56f83b84 --- /dev/null +++ b/backend/migrations/079_ops_error_logs_add_endpoint_fields.sql @@ -0,0 +1,28 @@ +-- Ops error logs: add endpoint, model mapping, and request_type 
fields +-- to match usage_logs observability coverage. +-- +-- All columns are nullable with no default to preserve backward compatibility +-- with existing rows. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 1) Standardized endpoint paths (analogous to usage_logs.inbound_endpoint / upstream_endpoint) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS inbound_endpoint VARCHAR(256), + ADD COLUMN IF NOT EXISTS upstream_endpoint VARCHAR(256); + +-- 2) Model mapping fields (analogous to usage_logs.requested_model / upstream_model) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS requested_model VARCHAR(100), + ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); + +-- 3) Granular request type enum (analogous to usage_logs.request_type: 0=unknown, 1=sync, 2=stream, 3=ws_v2) +ALTER TABLE ops_error_logs + ADD COLUMN IF NOT EXISTS request_type SMALLINT; + +COMMENT ON COLUMN ops_error_logs.inbound_endpoint IS 'Normalized client-facing API endpoint path, e.g. /v1/chat/completions. Populated from InboundEndpointMiddleware.'; +COMMENT ON COLUMN ops_error_logs.upstream_endpoint IS 'Normalized upstream endpoint path derived from platform, e.g. /v1/responses.'; +COMMENT ON COLUMN ops_error_logs.requested_model IS 'Client-requested model name before mapping (raw from request body).'; +COMMENT ON COLUMN ops_error_logs.upstream_model IS 'Actual model sent to upstream provider after mapping. NULL means no mapping applied.'; +COMMENT ON COLUMN ops_error_logs.request_type IS 'Request type enum: 0=unknown, 1=sync, 2=stream, 3=ws_v2. 
Matches usage_logs.request_type semantics.'; diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 751da25f..ece5a30f 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -36,6 +36,7 @@ export async function list( status?: string group?: string search?: string + privacy_mode?: string lite?: string }, options?: { @@ -68,6 +69,7 @@ export async function listWithEtag( status?: string group?: string search?: string + privacy_mode?: string lite?: string }, options?: { @@ -550,14 +552,18 @@ export async function getAntigravityDefaultModelMapping(): Promise> { - const payload: { refresh_token: string; proxy_id?: number } = { + const payload: { refresh_token: string; proxy_id?: number; client_id?: string } = { refresh_token: refreshToken } if (proxyId) { payload.proxy_id = proxyId } + if (clientId) { + payload.client_id = clientId + } const { data } = await apiClient.post>(endpoint, payload) return data } diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index 64f6a6d0..ac58eff4 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -969,6 +969,13 @@ export interface OpsErrorLog { client_ip?: string | null request_path?: string stream?: boolean + + // Error observability context (endpoint + model mapping) + inbound_endpoint?: string + upstream_endpoint?: string + requested_model?: string + upstream_model?: string + request_type?: number | null } export interface OpsErrorDetail extends OpsErrorLog { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 0519d2fc..83258bcc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { CustomMenuItem } from '@/types' +import type { CustomMenuItem, CustomEndpoint } from '@/types' export interface DefaultSubscriptionSetting { group_id: number @@ -43,6 +43,7 @@ export 
interface SystemSettings { sora_client_enabled: boolean backend_mode_enabled: boolean custom_menu_items: CustomMenuItem[] + custom_endpoints: CustomEndpoint[] // SMTP settings smtp_host: string smtp_port: number @@ -112,6 +113,7 @@ export interface UpdateSettingsRequest { sora_client_enabled?: boolean backend_mode_enabled?: boolean custom_menu_items?: CustomMenuItem[] + custom_endpoints?: CustomEndpoint[] smtp_host?: string smtp_port?: number smtp_username?: string diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 68dc4fcc..2934fbd9 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -661,6 +661,43 @@ + +
+
+ + +
+
+

+ {{ t('admin.accounts.openai.wsModeDesc') }} +

+

+ {{ t(openAIWSModeConcurrencyHintKey) }} +

+ + {{ + t('admin.accounts.oauth.openai.mobileRefreshTokenAuth', '手动输入 Mobile RT') + }} +
- -
+ +
@@ -759,6 +770,7 @@ interface Props { methodLabel?: string showCookieOption?: boolean // Whether to show cookie auto-auth option showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) + showMobileRefreshTokenOption?: boolean // Whether to show mobile refresh token option (OpenAI only) showSessionTokenOption?: boolean // Whether to show session token input option (Sora only) showAccessTokenOption?: boolean // Whether to show access token input option (Sora only) platform?: AccountPlatform // Platform type for different UI/text @@ -776,6 +788,7 @@ const props = withDefaults(defineProps(), { methodLabel: 'Authorization Method', showCookieOption: true, showRefreshTokenOption: false, + showMobileRefreshTokenOption: false, showSessionTokenOption: false, showAccessTokenOption: false, platform: 'anthropic', @@ -787,6 +800,7 @@ const emit = defineEmits<{ 'exchange-code': [code: string] 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] + 'validate-mobile-refresh-token': [refreshToken: string] 'validate-session-token': [sessionToken: string] 'import-access-token': [accessToken: string] 'update:inputMethod': [method: AuthInputMethod] @@ -834,7 +848,7 @@ const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showMobileRefreshTokenOption || props.showSessionTokenOption || props.showAccessTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -945,7 +959,11 @@ const handleCookieAuth = () => { const handleValidateRefreshToken = () => { if (refreshTokenInput.value.trim()) { - emit('validate-refresh-token', 
refreshTokenInput.value.trim()) + if (inputMethod.value === 'mobile_refresh_token') { + emit('validate-mobile-refresh-token', refreshTokenInput.value.trim()) + } else { + emit('validate-refresh-token', refreshTokenInput.value.trim()) + } } } diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts index 6458359e..7390e723 100644 --- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts @@ -149,6 +149,35 @@ describe('BulkEditAccountModal', () => { }) }) + it('OpenAI OAuth 批量编辑应提交 OAuth 专属 WS mode 字段', async () => { + const wrapper = mountModal({ + selectedPlatforms: ['openai'], + selectedTypes: ['oauth'] + }) + + await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true) + await wrapper.get('[data-testid="bulk-edit-openai-ws-mode-select"]').setValue('passthrough') + await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent') + await flushPromises() + + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1) + expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], { + extra: { + openai_oauth_responses_websockets_v2_mode: 'passthrough', + openai_oauth_responses_websockets_v2_enabled: true + } + }) + }) + + it('OpenAI API Key 批量编辑不显示 WS mode 入口', () => { + const wrapper = mountModal({ + selectedPlatforms: ['openai'], + selectedTypes: ['apikey'] + }) + + expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false) + }) + it('OpenAI 账号批量编辑可关闭自动透传', async () => { const wrapper = mountModal({ selectedPlatforms: ['openai'], diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index a2a9ab04..43e703ec 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue 
@@ -10,6 +10,7 @@ +
+
+ + +
+
+ + +
+
+
+
+ + + +
+
+
{{ t('admin.ops.errorDetail.requestType') }}
+
+ {{ formatRequestTypeLabel(detail.request_type) }} +
+
+
{{ t('admin.ops.errorDetail.message') }}
@@ -213,6 +241,31 @@ function isUpstreamError(d: OpsErrorDetail | null): boolean { return phase === 'upstream' && owner === 'provider' } +function formatRequestTypeLabel(type: number | null | undefined): string { + switch (type) { + case 1: return t('admin.ops.errorDetail.requestTypeSync') + case 2: return t('admin.ops.errorDetail.requestTypeStream') + case 3: return t('admin.ops.errorDetail.requestTypeWs') + default: return t('admin.ops.errorDetail.requestTypeUnknown') + } +} + +function hasModelMapping(d: OpsErrorDetail | null): boolean { + if (!d) return false + const requested = String(d.requested_model || '').trim() + const upstream = String(d.upstream_model || '').trim() + return !!requested && !!upstream && requested !== upstream +} + +function displayModel(d: OpsErrorDetail | null): string { + if (!d) return '' + const upstream = String(d.upstream_model || '').trim() + if (upstream) return upstream + const requested = String(d.requested_model || '').trim() + if (requested) return requested + return String(d.model || '').trim() +} + const correlatedUpstream = ref([]) const correlatedUpstreamLoading = ref(false) diff --git a/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue b/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue index 28868552..2b3825a2 100644 --- a/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue +++ b/frontend/src/views/admin/ops/components/OpsErrorLogTable.vue @@ -17,6 +17,9 @@ {{ t('admin.ops.errorLog.type') }} + + {{ t('admin.ops.errorLog.endpoint') }} + {{ t('admin.ops.errorLog.platform') }} @@ -42,7 +45,7 @@ - + {{ t('admin.ops.errorLog.noErrors') }} @@ -74,6 +77,18 @@ + + +
+ + + {{ log.inbound_endpoint }} + + + - +
+ + @@ -83,11 +98,22 @@ -
- - {{ log.model }} - - - +
+ +
@@ -138,6 +164,12 @@ > {{ log.severity }} + + {{ formatRequestType(log.request_type) }} +
@@ -193,6 +225,44 @@ function isUpstreamRow(log: OpsErrorLog): boolean { return phase === 'upstream' && owner === 'provider' } +function formatEndpointTooltip(log: OpsErrorLog): string { + const parts: string[] = [] + if (log.inbound_endpoint) parts.push(`Inbound: ${log.inbound_endpoint}`) + if (log.upstream_endpoint) parts.push(`Upstream: ${log.upstream_endpoint}`) + return parts.join('\n') || '' +} + +function hasModelMapping(log: OpsErrorLog): boolean { + const requested = String(log.requested_model || '').trim() + const upstream = String(log.upstream_model || '').trim() + return !!requested && !!upstream && requested !== upstream +} + +function modelMappingTooltip(log: OpsErrorLog): string { + const requested = String(log.requested_model || '').trim() + const upstream = String(log.upstream_model || '').trim() + if (!requested && !upstream) return '' + if (requested && upstream) return `${requested} → ${upstream}` + return upstream || requested +} + +function displayModel(log: OpsErrorLog): string { + const upstream = String(log.upstream_model || '').trim() + if (upstream) return upstream + const requested = String(log.requested_model || '').trim() + if (requested) return requested + return String(log.model || '').trim() +} + +function formatRequestType(type: number | null | undefined): string { + switch (type) { + case 1: return t('admin.ops.errorLog.requestTypeSync') + case 2: return t('admin.ops.errorLog.requestTypeStream') + case 3: return t('admin.ops.errorLog.requestTypeWs') + default: return '' + } +} + function getTypeBadge(log: OpsErrorLog): { label: string; className: string } { const phase = String(log.phase || '').toLowerCase() const owner = String(log.error_owner || '').toLowerCase() @@ -263,4 +333,4 @@ function formatSmartMessage(msg: string): string { return msg.length > 200 ? msg.substring(0, 200) + '...' 
: msg } - \ No newline at end of file + diff --git a/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue b/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue index a2f1adc3..d2aeb3ca 100644 --- a/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue +++ b/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue @@ -344,7 +344,7 @@ onMounted(async () => {
运行时日志配置(实时生效)
加载中...
-
+
-
- - - - +
+
+
+ + +
+
+ + +
+

最近写入错误:{{ health.last_error }}

diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 836fd2cb..c3335eb7 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -2,24 +2,31 @@ @@ -1050,6 +1057,7 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue' import SearchInput from '@/components/common/SearchInput.vue' import Icon from '@/components/icons/Icon.vue' import UseKeyModal from '@/components/keys/UseKeyModal.vue' + import EndpointPopover from '@/components/keys/EndpointPopover.vue' import GroupBadge from '@/components/common/GroupBadge.vue' import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import type { ApiKey, Group, PublicSettings, SubscriptionType, GroupPlatform } from '@/types'