package service import ( "bufio" "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "log" "net/http" "regexp" "sort" "strconv" "strings" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" ) const ( // ChatGPT internal API for OAuth accounts chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses" // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL ) // openaiSSEDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) // OpenAI allowed headers whitelist (for non-OAuth accounts) var openaiAllowedHeaders = map[string]bool{ "accept-language": true, "content-type": true, "conversation_id": true, "user-agent": true, "originator": true, "session_id": true, } // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers type OpenAICodexUsageSnapshot struct { PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"` PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"` SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"` SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"` SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"` PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"` UpdatedAt string `json:"updated_at,omitempty"` } // NormalizedCodexLimits contains normalized 5h/7d rate limit data type NormalizedCodexLimits struct { Used5hPercent *float64 Reset5hSeconds *int Window5hMinutes *int Used7dPercent *float64 Reset7dSeconds *int Window7dMinutes *int } // Normalize converts primary/secondary fields to canonical 5h/7d fields. // Strategy: Compare window_minutes to determine which is 5h vs 7d. // Returns nil if snapshot is nil or has no useful data. func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits { if s == nil { return nil } result := &NormalizedCodexLimits{} primaryMins := 0 secondaryMins := 0 hasPrimaryWindow := false hasSecondaryWindow := false if s.PrimaryWindowMinutes != nil { primaryMins = *s.PrimaryWindowMinutes hasPrimaryWindow = true } if s.SecondaryWindowMinutes != nil { secondaryMins = *s.SecondaryWindowMinutes hasSecondaryWindow = true } // Determine mapping based on window_minutes use5hFromPrimary := false use7dFromPrimary := false if hasPrimaryWindow && hasSecondaryWindow { // Both known: smaller window is 5h, larger is 7d if primaryMins < secondaryMins { use5hFromPrimary = true } else { use7dFromPrimary = true } } else if hasPrimaryWindow { // Only primary known: classify by threshold (<=360 min = 6h -> 5h window) if primaryMins <= 360 { use5hFromPrimary = true } else { use7dFromPrimary = true } } else if hasSecondaryWindow { // Only secondary known: classify by threshold if secondaryMins <= 360 { // 5h from secondary, so primary (if any data) is 7d use7dFromPrimary = true } else { // 7d from secondary, so primary (if any data) is 5h use5hFromPrimary = true } } else { // No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h) use7dFromPrimary = true } // Assign values if use5hFromPrimary { result.Used5hPercent = s.PrimaryUsedPercent result.Reset5hSeconds = s.PrimaryResetAfterSeconds result.Window5hMinutes = s.PrimaryWindowMinutes result.Used7dPercent = s.SecondaryUsedPercent result.Reset7dSeconds = s.SecondaryResetAfterSeconds result.Window7dMinutes = s.SecondaryWindowMinutes } else if use7dFromPrimary { result.Used7dPercent = s.PrimaryUsedPercent result.Reset7dSeconds = s.PrimaryResetAfterSeconds result.Window7dMinutes = s.PrimaryWindowMinutes result.Used5hPercent = s.SecondaryUsedPercent result.Reset5hSeconds = s.SecondaryResetAfterSeconds result.Window5hMinutes = s.SecondaryWindowMinutes } return result } // OpenAIUsage represents OpenAI API response usage type OpenAIUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` } // OpenAIForwardResult represents the result of forwarding type OpenAIForwardResult struct { RequestID string Usage OpenAIUsage Model string // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. ReasoningEffort *string Stream bool Duration time.Duration FirstTokenMs *int } // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository usageLogRepo UsageLogRepository userRepo UserRepository userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config schedulerSnapshot *SchedulerSnapshotService concurrencyService *ConcurrencyService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService httpUpstream HTTPUpstream deferredService *DeferredService openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector } // NewOpenAIGatewayService creates a new OpenAIGatewayService func NewOpenAIGatewayService( accountRepo AccountRepository, usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { return &OpenAIGatewayService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, cache: cache, cfg: cfg, schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, httpUpstream: httpUpstream, deferredService: deferredService, openAITokenProvider: openAITokenProvider, toolCorrector: NewCodexToolCorrector(), } } // GenerateSessionHash generates a sticky-session hash for OpenAI requests. // // Priority: // 1. Header: session_id // 2. Header: conversation_id // 3. Body: prompt_cache_key (opencode) func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { if c == nil { return "" } sessionID := strings.TrimSpace(c.GetHeader("session_id")) if sessionID == "" { sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) } if sessionID == "" && reqBody != nil { if v, ok := reqBody["prompt_cache_key"].(string); ok { sessionID = strings.TrimSpace(v) } } if sessionID == "" { return "" } hash := sha256.Sum256([]byte(sessionID)) return hex.EncodeToString(hash[:]) } // BindStickySession sets session -> account binding with standard TTL. func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { return nil } return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) } // SelectAccount selects an OpenAI account with sticky session support func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") } // SelectAccountForModel selects an account supporting the requested model func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) } // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { cacheKey := "openai:" + sessionHash // 1. 尝试粘性会话命中 // Try sticky session hit if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { return account, nil } // 2. 获取可调度的 OpenAI 账号 // Get schedulable OpenAI accounts accounts, err := s.listSchedulableAccounts(ctx, groupID) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel) } return nil, errors.New("no available OpenAI accounts") } // 4. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) } return selected, nil } // tryStickySessionHit 尝试从粘性会话获取账号。 // 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { if sessionHash == "" { return nil } accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) if err != nil || accountID <= 0 { return nil } if _, excluded := excludedIDs[accountID]; excluded { return nil } account, err := s.getSchedulableAccount(ctx, accountID) if err != nil { return nil } // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared if shouldClearStickySession(account) { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) return nil } // 验证账号是否可用于当前请求 // Verify account is usable for current request if !account.IsSchedulable() || !account.IsOpenAI() { return nil } if requestedModel != "" && !account.IsModelSupported(requestedModel) { return nil } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) return account } // selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。 // 返回 nil 表示无可用账号。 // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account for i := range accounts { acc := &accounts[i] // 跳过被排除的账号 // Skip excluded accounts if _, excluded := excludedIDs[acc.ID]; excluded { continue } // 调度器快照可能暂时过时,这里重新检查可调度性和平台 // Scheduler snapshots can be temporarily stale; re-check schedulability and platform if !acc.IsSchedulable() || !acc.IsOpenAI() { continue } // 检查模型支持 // Check model support if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used if selected == nil { selected = acc continue } if s.isBetterAccount(acc, selected) { selected = acc } } return selected } // isBetterAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 // // isBetterAccount checks if candidate is better than current. // Rules: higher priority (lower value) wins; same priority: never used > least recently used. func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool { // 优先级更高(数值更小) // Higher priority (lower value) if candidate.Priority < current.Priority { return true } if candidate.Priority > current.Priority { return false } // 同优先级,比较最后使用时间 // Same priority, compare last used time switch { case candidate.LastUsedAt == nil && current.LastUsedAt != nil: // candidate 从未使用,优先 return true case candidate.LastUsedAt != nil && current.LastUsedAt == nil: // current 从未使用,保持 return false case candidate.LastUsedAt == nil && current.LastUsedAt == nil: // 都未使用,保持 return false default: // 都使用过,选择最久未使用的 return candidate.LastUsedAt.Before(*current.LastUsedAt) } } // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) if err != nil { return nil, err } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { return &AccountSelectionResult{ Account: account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: account.ID, MaxConcurrency: account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } } return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: account.ID, MaxConcurrency: account.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, }, nil } accounts, err := s.listSchedulableAccounts(ctx, groupID) if err != nil { return nil, err } if len(accounts) == 0 { return nil, errors.New("no available accounts") } isExcluded := func(accountID int64) bool { if excludedIDs == nil { return false } _, excluded := excludedIDs[accountID] return excluded } // ============ Layer 1: Sticky session ============ if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: accountID, MaxConcurrency: account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } } } } } // ============ Layer 2: Load-aware selection ============ candidates := make([]*Account, 0, len(accounts)) for i := range accounts { acc := &accounts[i] if isExcluded(acc.ID) { continue } // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. if !acc.IsSchedulable() { continue } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } candidates = append(candidates, acc) } if len(candidates) == 0 { return nil, errors.New("no available accounts") } accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, MaxConcurrency: acc.Concurrency, }) } loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, false) for _, acc := range ordered { result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } } else { type accountWithLoad struct { account *Account loadInfo *AccountLoadInfo } var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] if loadInfo == nil { loadInfo = &AccountLoadInfo{AccountID: acc.ID} } if loadInfo.LoadRate < 100 { available = append(available, accountWithLoad{ account: acc, loadInfo: loadInfo, }) } } if len(available) > 0 { sort.SliceStable(available, func(i, j int) bool { a, b := available[i], available[j] if a.account.Priority != b.account.Priority { return a.account.Priority < b.account.Priority } if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return a.loadInfo.LoadRate < b.loadInfo.LoadRate } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: return false case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: return false default: return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } } } // ============ Layer 3: Fallback wait ============ sortAccountsByPriorityAndLastUsed(candidates, false) for _, acc := range candidates { return &AccountSelectionResult{ Account: acc, WaitPlan: &AccountWaitPlan{ AccountID: acc.ID, MaxConcurrency: acc.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, }, nil } return nil, errors.New("no available accounts") } func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { if s.schedulerSnapshot != nil { accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false) return accounts, err } var accounts []Account var err error if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } return accounts, nil } func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { if s.concurrencyService == nil { return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil } return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { if s.schedulerSnapshot != nil { return s.schedulerSnapshot.GetAccount(ctx, accountID) } return s.accountRepo.GetByID(ctx, accountID) } func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling } return config.GatewaySchedulingConfig{ StickySessionMaxWaiting: 3, StickySessionWaitTimeout: 45 * time.Second, FallbackWaitTimeout: 30 * time.Second, FallbackMaxWaiting: 100, LoadBatchEnabled: true, SlotCleanupInterval: 30 * time.Second, } } // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { case AccountTypeOAuth: // 使用 TokenProvider 获取缓存的 token if s.openAITokenProvider != nil { accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account) if err != nil { return "", "", err } return accessToken, "oauth", nil } // 降级:TokenProvider 未配置时直接从账号读取 accessToken := account.GetOpenAIAccessToken() if accessToken == "" { return "", "", errors.New("access_token not found in credentials") } return accessToken, "oauth", nil case AccountTypeAPIKey: apiKey := account.GetOpenAIApiKey() if apiKey == "" { return "", "", errors.New("api_key not found in credentials") } return apiKey, "apikey", nil default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } } func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { case 401, 402, 403, 429, 529: return true default: return statusCode >= 500 } } func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } // Forward forwards request to OpenAI API func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() // Parse request body once (avoid multiple parse/serialize cycles) var reqBody map[string]any if err := json.Unmarshal(body, &reqBody); err != nil { return nil, fmt.Errorf("parse request: %w", err) } // Extract model and stream from parsed body reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) promptCacheKey := "" if v, ok := reqBody["prompt_cache_key"].(string); ok { promptCacheKey = strings.TrimSpace(v) } // Track if body needs re-serialization bodyModified := false originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true } // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { normalizedModel := normalizeCodexModel(model) if normalizedModel != "" && normalizedModel != model { log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", model, normalizedModel, account.Name, account.Type, isCodexCLI) reqBody["model"] = normalizedModel mappedModel = normalizedModel bodyModified = true } } // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } if account.Type == AccountTypeOAuth { codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true } if codexResult.NormalizedModel != "" { mappedModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } } // Handle max_output_tokens based on platform and account type if !isCodexCLI { if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { switch account.Platform { case PlatformOpenAI: // For OpenAI API Key, remove max_output_tokens (not supported) // For OpenAI OAuth (Responses API), keep it (supported) if account.Type == AccountTypeAPIKey { delete(reqBody, "max_output_tokens") bodyModified = true } case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { reqBody["max_tokens"] = maxOutputTokens } bodyModified = true case PlatformGemini: // For Gemini, remove (will be handled by Gemini-specific transform) delete(reqBody, "max_output_tokens") bodyModified = true default: // For unknown platforms, remove to be safe delete(reqBody, "max_output_tokens") bodyModified = true } } // Also handle max_completion_tokens (similar logic) if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens { if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { delete(reqBody, "max_completion_tokens") bodyModified = true } } // Remove prompt_cache_retention (not supported by upstream OpenAI API) if _, has := reqBody["prompt_cache_retention"]; has { delete(reqBody, "prompt_cache_retention") bodyModified = true } } // Re-serialize body only if modified if bodyModified { var err error body, err = json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("serialize request body: %w", err) } } // Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { return nil, err } // Build upstream request upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { return nil, err } // Get proxy URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } // Capture upstream request body for ops retry of this attempt. if c != nil { c.Set(OpsUpstreamRequestBodyKey, string(body)) } // Send request resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). safeErr := sanitizeUpstreamErrorMessage(err.Error()) setOpsUpstreamError(c, 0, safeErr, "") appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, Kind: "request_error", Message: safeErr, }) c.JSON(http.StatusBadGateway, gin.H{ "error": gin.H{ "type": "upstream_error", "message": "Upstream request failed", }, }) return nil, fmt.Errorf("upstream request failed: %s", safeErr) } defer func() { _ = resp.Body.Close() }() // Handle error response if resp.StatusCode >= 400 { if s.shouldFailoverUpstreamError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "failover", Message: upstreamMsg, Detail: upstreamDetail, }) s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } return s.handleErrorResponse(ctx, resp, c, account) } // Handle normal response var usage *OpenAIUsage var firstTokenMs *int if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) if err != nil { return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) if err != nil { return nil, err } } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) if account.Type == AccountTypeOAuth { if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) } } reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) return &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, ReasoningEffort: reasoningEffort, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { case AccountTypeOAuth: // OAuth accounts use ChatGPT internal API targetURL = chatgptCodexURL case AccountTypeAPIKey: // API Key accounts use Platform API or custom base URL baseURL := account.GetOpenAIBaseURL() if baseURL == "" { targetURL = openaiPlatformAPIURL } else { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err } targetURL = validatedURL + "/responses" } default: targetURL = openaiPlatformAPIURL } req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err } // Set authentication header req.Header.Set("authorization", "Bearer "+token) // Set headers specific to OAuth accounts (ChatGPT internal API) if account.Type == AccountTypeOAuth { // Required: set Host for ChatGPT API (must use req.Host, not Header.Set) req.Host = "chatgpt.com" // Required: set chatgpt-account-id header chatgptAccountID := account.GetChatGPTAccountID() if chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } } // Whitelist passthrough headers for key, values := range c.Request.Header { lowerKey := strings.ToLower(key) if openaiAllowedHeaders[lowerKey] { for _, v := range values { req.Header.Add(key, v) } } } if account.Type == AccountTypeOAuth { req.Header.Set("OpenAI-Beta", "responses=experimental") if isCodexCLI { req.Header.Set("originator", "codex_cli_rs") } else { req.Header.Set("originator", "opencode") } req.Header.Set("accept", "text/event-stream") if promptCacheKey != "" { req.Header.Set("conversation_id", promptCacheKey) req.Header.Set("session_id", promptCacheKey) } } // Apply custom User-Agent if configured customUA := account.GetOpenAIUserAgent() if customUA != "" { req.Header.Set("user-agent", customUA) } // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") } return req, nil } func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(body), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { log.Printf( "OpenAI upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, account.Platform, account.Type, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } // Check custom error codes if !account.ShouldHandleErrorCode(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "http_error", Message: upstreamMsg, Detail: upstreamDetail, }) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "type": "upstream_error", "message": "Upstream gateway error", }, }) if upstreamMsg == "" { return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) } // Handle upstream error (mark account status) shouldDisable := false if s.rateLimitService != nil { shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } kind := "http_error" if shouldDisable { kind = "failover" } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: kind, Message: upstreamMsg, Detail: upstreamDetail, }) if shouldDisable { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} } // Return appropriate error response var errType, errMsg string var statusCode int switch resp.StatusCode { case 401: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream authentication failed, please contact administrator" case 402: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream payment required: insufficient balance or billing issue" case 403: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream access forbidden, please contact administrator" case 429: statusCode = http.StatusTooManyRequests errType = "rate_limit_error" errMsg = "Upstream rate limit exceeded, please retry later" default: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream request failed" } c.JSON(statusCode, gin.H{ "error": gin.H{ "type": errType, "message": errMsg, }, }) if upstreamMsg == "" { return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage firstTokenMs *int } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { if s.cfg != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) } // Set SSE response headers c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") // Pass through other headers if v := resp.Header.Get("x-request-id"); v != "" { c.Header("x-request-id", v) } w := c.Writer flusher, ok := w.(http.Flusher) if !ok { return nil, errors.New("streaming not supported") } usage := &OpenAIUsage{} var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } scanner.Buffer(make([]byte, 64*1024), maxLineSize) type scanEvent struct { line string err error } // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { case events <- ev: return true case <-done: return false } } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } } if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } }() defer close(done) streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } // 仅监控上游数据间隔超时,不被下游写入阻塞影响 var intervalTicker *time.Ticker if streamInterval > 0 { intervalTicker = time.NewTicker(streamInterval) defer intervalTicker.Stop() } var intervalCh <-chan time.Time if intervalTicker != nil { intervalCh = intervalTicker.C } keepaliveInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second } // 下游 keepalive 仅用于防止代理空闲断开 var keepaliveTicker *time.Ticker if keepaliveInterval > 0 { keepaliveTicker = time.NewTicker(keepaliveInterval) defer keepaliveTicker.Stop() } var keepaliveCh <-chan time.Time if keepaliveTicker != nil { keepaliveCh = keepaliveTicker.C } // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 lastDataAt := time.Now() // 仅发送一次错误事件,避免多次写入导致协议混乱。 // 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema; // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return } errorEventSent = true payload := map[string]any{ "type": "error", "sequence_number": 0, "error": map[string]any{ "type": "upstream_error", "message": reason, "code": reason, }, } if b, err := json.Marshal(payload); err == nil { _, _ = fmt.Fprintf(w, "data: %s\n\n", b) flusher.Flush() } } needModelReplace := originalModel != mappedModel for { select { case ev, ok := <-events: if !ok { return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if ev.err != nil { // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { log.Printf("Context canceled during streaming, returning collected usage") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage if clientDisconnected { log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } sendErrorEvent("stream_read_error") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line lastDataAt = time.Now() // Extract data from SSE line (supports both "data: " and "data:" formats) if openaiSSEDataRe.MatchString(line) { data := openaiSSEDataRe.ReplaceAllString(line, "") // Replace model in response if needed if needModelReplace { line = s.replaceModelInSSELine(line, mappedModel, originalModel) } // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { data = correctedData line = "data: " + correctedData } // 写入客户端(客户端断开后继续 drain 上游) if !clientDisconnected { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { clientDisconnected = true log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") } else { flusher.Flush() } } // Record first token time if firstTokenMs == nil && data != "" && data != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsage(data, usage) } else { // Forward non-data lines as-is if !clientDisconnected { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { clientDisconnected = true log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") } else { flusher.Flush() } } } case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } if clientDisconnected { log.Printf("Upstream timeout after client disconnect, returning collected usage") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } sendErrorEvent("stream_timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: if clientDisconnected { continue } if time.Since(lastDataAt) < keepaliveInterval { continue } if _, err := fmt.Fprint(w, ":\n\n"); err != nil { clientDisconnected = true log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") continue } flusher.Flush() } } } func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { if !openaiSSEDataRe.MatchString(line) { return line } data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { return line } var event map[string]any if err := json.Unmarshal([]byte(data), &event); err != nil { return line } // Replace model in response if m, ok := event["model"].(string); ok && m == fromModel { event["model"] = toModel newData, err := json.Marshal(event) if err != nil { return line } return "data: " + string(newData) } // Check nested response if response, ok := event["response"].(map[string]any); ok { if m, ok := response["model"].(string); ok && m == fromModel { response["model"] = toModel newData, err := json.Marshal(event) if err != nil { return line } return "data: " + string(newData) } } return line } // correctToolCallsInResponseBody 修正响应体中的工具调用 func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte { if len(body) == 0 { return body } bodyStr := string(body) corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) if changed { return []byte(corrected) } return body } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { // Parse response.completed event for usage (OpenAI Responses format) var event struct { Type string `json:"type"` Response struct { Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` InputTokenDetails struct { CachedTokens int `json:"cached_tokens"` } `json:"input_tokens_details"` } `json:"usage"` } `json:"response"` } if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { usage.InputTokens = event.Response.Usage.InputTokens usage.OutputTokens = event.Response.Usage.OutputTokens usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens } } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } if account.Type == AccountTypeOAuth { bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE { return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel) } } // Parse usage var response struct { Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` InputTokenDetails struct { CachedTokens int `json:"cached_tokens"` } `json:"input_tokens_details"` } `json:"usage"` } if err := json.Unmarshal(body, &response); err != nil { return nil, fmt.Errorf("parse response: %w", err) } usage := &OpenAIUsage{ InputTokens: response.Usage.InputTokens, OutputTokens: response.Usage.OutputTokens, CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, } // Replace model in response if needed if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { contentType = upstreamType } } c.Data(resp.StatusCode, contentType, body) return usage, nil } func isEventStreamResponse(header http.Header) bool { contentType := strings.ToLower(header.Get("Content-Type")) return strings.Contains(contentType, "text/event-stream") } func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) usage := &OpenAIUsage{} if ok { var response struct { Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` InputTokenDetails struct { CachedTokens int `json:"cached_tokens"` } `json:"input_tokens_details"` } `json:"usage"` } if err := json.Unmarshal(finalResponse, &response); err == nil { usage.InputTokens = response.Usage.InputTokens usage.OutputTokens = response.Usage.OutputTokens usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens } body = finalResponse if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } // Correct tool calls in final response body = s.correctToolCallsInResponseBody(body) } else { usage = s.parseSSEUsageFromBody(bodyText) if originalModel != mappedModel { bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) } body = []byte(bodyText) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) contentType := "application/json; charset=utf-8" if !ok { contentType = resp.Header.Get("Content-Type") if contentType == "" { contentType = "text/event-stream" } } c.Data(resp.StatusCode, contentType, body) return usage, nil } func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { if !openaiSSEDataRe.MatchString(line) { continue } data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } var event struct { Type string `json:"type"` Response json.RawMessage `json:"response"` } if json.Unmarshal([]byte(data), &event) != nil { continue } if event.Type == "response.done" || event.Type == "response.completed" { if len(event.Response) > 0 { return event.Response, true } } } return nil, false } func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") for _, line := range lines { if !openaiSSEDataRe.MatchString(line) { continue } data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } s.parseSSEUsage(data, usage) } return usage } func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { lines := strings.Split(body, "\n") for i, line := range lines { if !openaiSSEDataRe.MatchString(line) { continue } lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) } return strings.Join(lines, "\n") } func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { return "", fmt.Errorf("invalid base_url: %w", err) } return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, RequireAllowlist: true, AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { return "", fmt.Errorf("invalid base_url: %w", err) } return normalized, nil } func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { var resp map[string]any if err := json.Unmarshal(body, &resp); err != nil { return body } model, ok := resp["model"].(string) if !ok || model != fromModel { return body } resp["model"] = toModel newBody, err := json.Marshal(resp) if err != nil { return body } return newBody } // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult APIKey *APIKey User *User Account *Account Subscription *UserSubscription UserAgent string // 请求的 User-Agent IPAddress string // 请求的客户端 IP 地址 APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result apiKey := input.APIKey user := input.User account := input.Account subscription := input.Subscription // 计算实际的新输入token(减去缓存读取的token) // 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费 actualInputTokens := result.Usage.InputTokens - result.Usage.CacheReadInputTokens if actualInputTokens < 0 { actualInputTokens = 0 } // Calculate cost tokens := UsageTokens{ InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, } // Get rate multiplier multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier } cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } // Determine billing type isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() billingType := BillingTypeBalance if isSubscriptionBilling { billingType = BillingTypeSubscription } // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, CacheReadCost: cost.CacheReadCost, TotalCost: cost.TotalCost, ActualCost: cost.ActualCost, RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), } // 添加 UserAgent if input.UserAgent != "" { usageLog.UserAgent = &input.UserAgent } // 添加 IPAddress if input.IPAddress != "" { usageLog.IPAddress = &input.IPAddress } if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID } if subscription != nil { usageLog.SubscriptionID = &subscription.ID } inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } shouldBill := inserted || err != nil // Deduct based on billing type if isSubscriptionBilling { if shouldBill && cost.TotalCost > 0 { _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { if shouldBill && cost.ActualCost > 0 { _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } } // Update API key quota if applicable (only for balance mode with quota set) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { log.Printf("Update API key quota failed: %v", err) } } // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. // Exported for use in ratelimit_service when handling OpenAI 429 responses. func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { snapshot := &OpenAICodexUsageSnapshot{} hasData := false // Helper to parse float64 from header parseFloat := func(key string) *float64 { if v := headers.Get(key); v != "" { if f, err := strconv.ParseFloat(v, 64); err == nil { return &f } } return nil } // Helper to parse int from header parseInt := func(key string) *int { if v := headers.Get(key); v != "" { if i, err := strconv.Atoi(v); err == nil { return &i } } return nil } // Primary (weekly) limits if v := parseFloat("x-codex-primary-used-percent"); v != nil { snapshot.PrimaryUsedPercent = v hasData = true } if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil { snapshot.PrimaryResetAfterSeconds = v hasData = true } if v := parseInt("x-codex-primary-window-minutes"); v != nil { snapshot.PrimaryWindowMinutes = v hasData = true } // Secondary (5h) limits if v := parseFloat("x-codex-secondary-used-percent"); v != nil { snapshot.SecondaryUsedPercent = v hasData = true } if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil { snapshot.SecondaryResetAfterSeconds = v hasData = true } if v := parseInt("x-codex-secondary-window-minutes"); v != nil { snapshot.SecondaryWindowMinutes = v hasData = true } // Overflow ratio if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil { snapshot.PrimaryOverSecondaryPercent = v hasData = true } if !hasData { return nil } snapshot.UpdatedAt = time.Now().Format(time.RFC3339) return snapshot } // updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { if snapshot == nil { return } // Convert snapshot to map for merging into Extra updates := make(map[string]any) // Save raw primary/secondary fields for debugging/tracing if snapshot.PrimaryUsedPercent != nil { updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent } if snapshot.PrimaryResetAfterSeconds != nil { updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds } if snapshot.PrimaryWindowMinutes != nil { updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes } if snapshot.SecondaryUsedPercent != nil { updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent } if snapshot.SecondaryResetAfterSeconds != nil { updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds } if snapshot.SecondaryWindowMinutes != nil { updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes } if snapshot.PrimaryOverSecondaryPercent != nil { updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent } updates["codex_usage_updated_at"] = snapshot.UpdatedAt // Normalize to canonical 5h/7d fields if normalized := snapshot.Normalize(); normalized != nil { if normalized.Used5hPercent != nil { updates["codex_5h_used_percent"] = *normalized.Used5hPercent } if normalized.Reset5hSeconds != nil { updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds } if normalized.Window5hMinutes != nil { updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes } if normalized.Used7dPercent != nil { updates["codex_7d_used_percent"] = *normalized.Used7dPercent } if normalized.Reset7dSeconds != nil { updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds } if normalized.Window7dMinutes != nil { updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes } } // Update account's Extra field asynchronously go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) }() } func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { if reqBody == nil { return "", false } // Primary: reasoning.effort if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { if effort, ok := reasoning["effort"].(string); ok { return normalizeOpenAIReasoningEffort(effort), true } } // Fallback: some clients may use a flat field. if effort, ok := reqBody["reasoning_effort"].(string); ok { return normalizeOpenAIReasoningEffort(effort), true } return "", false } func deriveOpenAIReasoningEffortFromModel(model string) string { if strings.TrimSpace(model) == "" { return "" } modelID := strings.TrimSpace(model) if strings.Contains(modelID, "/") { parts := strings.Split(modelID, "/") modelID = parts[len(parts)-1] } parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { switch r { case '-', '_', ' ': return true default: return false } }) if len(parts) == 0 { return "" } return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { return nil } return &value } value := deriveOpenAIReasoningEffortFromModel(requestedModel) if value == "" { return nil } return &value } func normalizeOpenAIReasoningEffort(raw string) string { value := strings.ToLower(strings.TrimSpace(raw)) if value == "" { return "" } // Normalize separators for "x-high"/"x_high" variants. value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value) switch value { case "none", "minimal": return "" case "low", "medium", "high": return value case "xhigh", "extrahigh": return "xhigh" default: // Only store known effort levels for now to keep UI consistent. return "" } }