merge: resolve conflict with main (keep both openAI probe and usage fix)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
httppool "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
openaipkg "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
@@ -92,10 +97,12 @@ type antigravityUsageCache struct {
|
||||
}
|
||||
|
||||
const (
	// apiCacheTTL is the lifetime of cached API usage results.
	apiCacheTTL = 3 * time.Minute
	// apiErrorCacheTTL is the negative-cache TTL: errors such as 429 are
	// cached for one minute to prevent retry storms.
	apiErrorCacheTTL = 1 * time.Minute
	// apiQueryMaxJitter is the maximum random delay applied to usage queries,
	// spreading out concurrent requests to avoid anti-abuse detection.
	apiQueryMaxJitter = 800 * time.Millisecond
	// windowStatsCacheTTL is the lifetime of cached window statistics.
	windowStatsCacheTTL = 1 * time.Minute
	// openAIProbeCacheTTL is the minimum interval between two Codex snapshot
	// probes for the same account.
	openAIProbeCacheTTL = 10 * time.Minute
	// openAICodexProbeVersion is the value sent in the "Version" header of the
	// Codex snapshot probe request.
	openAICodexProbeVersion = "0.104.0"
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
@@ -104,6 +111,7 @@ type UsageCache struct {
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
apiFlight singleflight.Group // 防止同一账号的并发请求击穿缓存
|
||||
openAIProbeCache sync.Map // accountID -> time.Time
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
@@ -231,6 +239,14 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth {
|
||||
usage, err := s.getOpenAIUsage(ctx, account)
|
||||
if err == nil {
|
||||
s.tryClearRecoverableAccountError(ctx, account)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
usage, err := s.getGeminiUsage(ctx, account)
|
||||
if err == nil {
|
||||
@@ -336,6 +352,161 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{UpdatedAt: &now}
|
||||
|
||||
if account == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
|
||||
if (usage.FiveHour == nil || usage.SevenDay == nil) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||
mergeAccountExtra(account, updates)
|
||||
if usage.UpdatedAt == nil {
|
||||
usage.UpdatedAt = &now
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||
usage.FiveHour = progress
|
||||
}
|
||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "7d", now); progress != nil {
|
||||
usage.SevenDay = progress
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.usageLogRepo == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-5*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.FiveHour == nil {
|
||||
usage.FiveHour = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
if stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, now.Add(-7*24*time.Hour)); err == nil {
|
||||
windowStats := windowStatsFromAccountStats(stats)
|
||||
if hasMeaningfulWindowStats(windowStats) {
|
||||
if usage.SevenDay == nil {
|
||||
usage.SevenDay = &UsageProgress{Utilization: 0}
|
||||
}
|
||||
usage.SevenDay.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
|
||||
if s == nil || s.cache == nil || accountID <= 0 {
|
||||
return true
|
||||
}
|
||||
if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
|
||||
if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
|
||||
return false
|
||||
}
|
||||
}
|
||||
s.cache.openAIProbeCache.Store(accountID, now)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
if account == nil || !account.IsOAuth() {
|
||||
return nil, nil
|
||||
}
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
modelID := openaipkg.DefaultTestModel
|
||||
payload := createOpenAITestPayload(modelID, true)
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||
}
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
req.Header.Set("Originator", "codex_cli_rs")
|
||||
req.Header.Set("Version", openAICodexProbeVersion)
|
||||
req.Header.Set("User-Agent", codexCLIUserAgent)
|
||||
if s.identityCache != nil {
|
||||
if fp, fpErr := s.identityCache.GetFingerprint(reqCtx, account.ID); fpErr == nil && fp != nil && strings.TrimSpace(fp.UserAgent) != "" {
|
||||
req.Header.Set("User-Agent", strings.TrimSpace(fp.UserAgent))
|
||||
}
|
||||
}
|
||||
if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
client, err := httppool.GetClient(httppool.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
updates := buildCodexUsageExtraUpdates(snapshot, time.Now())
|
||||
if len(updates) > 0 {
|
||||
go func(accountID int64, updates map[string]any) {
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||
}(account.ID, updates)
|
||||
return updates, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||
if account == nil || len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any, len(updates))
|
||||
}
|
||||
for k, v := range updates {
|
||||
account.Extra[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
@@ -567,6 +738,72 @@ func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
|
||||
}
|
||||
}
|
||||
|
||||
func hasMeaningfulWindowStats(stats *WindowStats) bool {
|
||||
if stats == nil {
|
||||
return false
|
||||
}
|
||||
return stats.Requests > 0 || stats.Tokens > 0 || stats.Cost > 0 || stats.StandardCost > 0 || stats.UserCost > 0
|
||||
}
|
||||
|
||||
func buildCodexUsageProgressFromExtra(extra map[string]any, window string, now time.Time) *UsageProgress {
|
||||
if len(extra) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
usedPercentKey string
|
||||
resetAfterKey string
|
||||
resetAtKey string
|
||||
)
|
||||
|
||||
switch window {
|
||||
case "5h":
|
||||
usedPercentKey = "codex_5h_used_percent"
|
||||
resetAfterKey = "codex_5h_reset_after_seconds"
|
||||
resetAtKey = "codex_5h_reset_at"
|
||||
case "7d":
|
||||
usedPercentKey = "codex_7d_used_percent"
|
||||
resetAfterKey = "codex_7d_reset_after_seconds"
|
||||
resetAtKey = "codex_7d_reset_at"
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
usedRaw, ok := extra[usedPercentKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
progress := &UsageProgress{Utilization: parseExtraFloat64(usedRaw)}
|
||||
if resetAtRaw, ok := extra[resetAtKey]; ok {
|
||||
if resetAt, err := parseTime(fmt.Sprint(resetAtRaw)); err == nil {
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if progress.ResetsAt == nil {
|
||||
if resetAfterSeconds := parseExtraInt(extra[resetAfterKey]); resetAfterSeconds > 0 {
|
||||
base := now
|
||||
if updatedAtRaw, ok := extra["codex_usage_updated_at"]; ok {
|
||||
if updatedAt, err := parseTime(fmt.Sprint(updatedAtRaw)); err == nil {
|
||||
base = updatedAt
|
||||
}
|
||||
}
|
||||
resetAt := base.Add(time.Duration(resetAfterSeconds) * time.Second)
|
||||
progress.ResetsAt = &resetAt
|
||||
progress.RemainingSeconds = int(time.Until(resetAt).Seconds())
|
||||
if progress.RemainingSeconds < 0 {
|
||||
progress.RemainingSeconds = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return progress
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
@@ -714,15 +951,30 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||
// 优先使用响应头中存储的真实 utilization 值(0-1 小数,转为 0-100 百分比)
|
||||
var utilization float64
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
default:
|
||||
utilization = 0.0
|
||||
var found bool
|
||||
if stored, ok := account.Extra["session_window_utilization"]; ok {
|
||||
switch v := stored.(type) {
|
||||
case float64:
|
||||
utilization = v * 100
|
||||
found = true
|
||||
case json.Number:
|
||||
if f, err := v.Float64(); err == nil {
|
||||
utilization = f * 100
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有存储的 utilization,回退到状态估算
|
||||
if !found {
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
}
|
||||
}
|
||||
|
||||
info.FiveHour = &UsageProgress{
|
||||
|
||||
@@ -49,7 +49,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
mappedModel := account.GetMappedModel(originalModel)
|
||||
responsesReq.Model = mappedModel
|
||||
|
||||
logger.L().Info("openai messages: model mapping applied",
|
||||
logger.L().Debug("openai messages: model mapping applied",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("original_model", originalModel),
|
||||
zap.String("mapped_model", mappedModel),
|
||||
@@ -67,7 +67,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
applyCodexOAuthTransform(reqBody, false, false)
|
||||
// OAuth codex transform forces stream=true upstream, so always use
|
||||
// the streaming response handler regardless of what the client asked.
|
||||
isStream = true
|
||||
@@ -148,9 +148,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
|
||||
// 9. Handle normal response
|
||||
if isStream {
|
||||
return s.handleAnthropicStreamingResponse(resp, c, originalModel, startTime)
|
||||
return s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
}
|
||||
return s.handleAnthropicNonStreamingResponse(resp, c, originalModel, startTime)
|
||||
return s.handleAnthropicNonStreamingResponse(resp, c, originalModel, mappedModel, startTime)
|
||||
}
|
||||
|
||||
// handleAnthropicErrorResponse reads an upstream error and returns it in
|
||||
@@ -200,6 +200,7 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
@@ -233,11 +234,12 @@ func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -247,6 +249,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
originalModel string,
|
||||
mappedModel string,
|
||||
startTime time.Time,
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
@@ -293,7 +296,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
}
|
||||
|
||||
// Extract usage from completion events
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete") &&
|
||||
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
|
||||
event.Response != nil && event.Response.Usage != nil {
|
||||
usage = OpenAIUsage{
|
||||
InputTokens: event.Response.Usage.InputTokens,
|
||||
@@ -324,6 +327,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -360,6 +364,7 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
|
||||
RequestID: requestID,
|
||||
Usage: usage,
|
||||
Model: originalModel,
|
||||
BillingModel: mappedModel,
|
||||
Stream: true,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
|
||||
@@ -207,12 +207,18 @@ type OpenAIUsage struct {
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Model string // 原始模型(用于响应和日志显示)
|
||||
// BillingModel is the model used for cost calculation.
|
||||
// When non-empty, CalculateCost uses this instead of Model.
|
||||
// This is set by the Anthropic Messages conversion path where
|
||||
// the mapped upstream model differs from the client-facing model.
|
||||
BillingModel string
|
||||
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
|
||||
// Stored for usage records display; nil means not provided / not applicable.
|
||||
ReasoningEffort *string
|
||||
Stream bool
|
||||
OpenAIWSMode bool
|
||||
ResponseHeaders http.Header
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
@@ -3610,7 +3616,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
billingModel := result.Model
|
||||
if result.BillingModel != "" {
|
||||
billingModel = result.BillingModel
|
||||
}
|
||||
cost, err := s.billingService.CalculateCost(billingModel, tokens, multiplier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -3630,7 +3640,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
Model: billingModel,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
InputTokens: actualInputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -3875,6 +3885,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) {
|
||||
if accountID <= 0 || headers == nil {
|
||||
return
|
||||
}
|
||||
if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, accountID, snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
|
||||
if reqBody == nil {
|
||||
return "", false
|
||||
|
||||
@@ -28,6 +28,22 @@ type stubOpenAIAccountRepo struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
type snapshotUpdateAccountRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCalls chan map[string]any
|
||||
}
|
||||
|
||||
func (r *snapshotUpdateAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
if r.updateExtraCalls != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCalls <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
@@ -1248,6 +1264,30 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIUpdateCodexUsageSnapshotFromHeaders(t *testing.T) {
|
||||
repo := &snapshotUpdateAccountRepo{updateExtraCalls: make(chan map[string]any, 1)}
|
||||
svc := &OpenAIGatewayService{accountRepo: repo}
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "12")
|
||||
headers.Set("x-codex-secondary-used-percent", "34")
|
||||
headers.Set("x-codex-primary-window-minutes", "300")
|
||||
headers.Set("x-codex-secondary-window-minutes", "10080")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "600")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "86400")
|
||||
|
||||
svc.UpdateCodexUsageSnapshotFromHeaders(context.Background(), 123, headers)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCalls:
|
||||
require.Equal(t, 12.0, updates["codex_5h_used_percent"])
|
||||
require.Equal(t, 34.0, updates["codex_7d_used_percent"])
|
||||
require.Equal(t, 600, updates["codex_5h_reset_after_seconds"])
|
||||
require.Equal(t, 86400, updates["codex_7d_reset_after_seconds"])
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected UpdateExtra to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1334,6 +1374,7 @@ func TestOpenAIBuildUpstreamRequestPreservesCompactPathForAPIKeyBaseURL(t *testi
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 =============
|
||||
func TestReplaceModelInSSELine(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
|
||||
@@ -2309,6 +2309,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
|
||||
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: lease.HandshakeHeaders(),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
@@ -2919,6 +2920,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
|
||||
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
|
||||
Stream: reqStream,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: lease.HandshakeHeaders(),
|
||||
Duration: time.Since(turnStart),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
|
||||
@@ -126,6 +126,13 @@ func (l *openAIWSConnLease) HandshakeHeader(name string) string {
|
||||
return l.conn.handshakeHeader(name)
|
||||
}
|
||||
|
||||
func (l *openAIWSConnLease) HandshakeHeaders() http.Header {
|
||||
if l == nil || l.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return cloneHeader(l.conn.handshakeHeaders)
|
||||
}
|
||||
|
||||
func (l *openAIWSConnLease) IsPrewarmed() bool {
|
||||
if l == nil || l.conn == nil {
|
||||
return false
|
||||
|
||||
@@ -177,11 +177,12 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: turn.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
Duration: turn.Duration,
|
||||
FirstTokenMs: turn.FirstTokenMs,
|
||||
Model: turn.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
Duration: turn.Duration,
|
||||
FirstTokenMs: turn.FirstTokenMs,
|
||||
}
|
||||
logOpenAIWSV2Passthrough(
|
||||
"relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
|
||||
@@ -223,11 +224,12 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
|
||||
CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
|
||||
},
|
||||
Model: relayResult.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
Duration: relayResult.Duration,
|
||||
FirstTokenMs: relayResult.FirstTokenMs,
|
||||
Model: relayResult.RequestModel,
|
||||
Stream: true,
|
||||
OpenAIWSMode: true,
|
||||
ResponseHeaders: cloneHeader(handshakeHeaders),
|
||||
Duration: relayResult.Duration,
|
||||
FirstTokenMs: relayResult.FirstTokenMs,
|
||||
}
|
||||
|
||||
turnCount := int(completedTurns.Load())
|
||||
|
||||
@@ -970,12 +970,27 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
windowStart = &start
|
||||
windowEnd = &end
|
||||
slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
|
||||
// 窗口重置时清除旧的 utilization,避免残留上个窗口的数据
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||
"session_window_utilization": nil,
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
|
||||
// 存储真实的 utilization 值(0-1 小数),供 estimateSetupTokenUsage 使用
|
||||
if utilStr := headers.Get("anthropic-ratelimit-unified-5h-utilization"); utilStr != "" {
|
||||
if util, err := strconv.ParseFloat(utilStr, 64); err == nil {
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||
"session_window_utilization": util,
|
||||
}); err != nil {
|
||||
slog.Warn("session_window_utilization_update_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user