feat(rpm): RPM 限流模块优化

P0:
- rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7)
- 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数)

P1:
- ClearAll 按钮直连 DELETE API,带 loading 防重复
- 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点

优化:
- checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效
- Override/Group 变更后自动失效 auth cache
- fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
james-6-23
2026-04-23 03:33:52 +08:00
parent ef967d8f8a
commit dc5d42addc
79 changed files with 2831 additions and 140 deletions

View File

@@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
user, err := s.GetUser(ctx, userID)
if err != nil {
return nil, err
}
return &service.UserRPMStatus{
UserRPMUsed: 0,
UserRPMLimit: user.RPMLimit,
}, nil
}
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
s.boundAuthIdentityFor = userID
copied := input
@@ -276,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType

View File

@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限0 = 不限制)
RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 分组 RPM 上限0 = 不限制nil 表示未提供不改动
RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
type BatchSetGroupRPMOverridesRequest struct {
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
}
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
// PUT /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRPMOverridesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
}
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
// DELETE /api/v1/admin/groups/:id/rpm-overrides
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {

View File

@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,

View File

@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
RPMLimit int `json:"rpm_limit"`
AllowedGroups []int64 `json:"allowed_groups"`
}
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
RPMLimit *int `json:"rpm_limit"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys": result.MigratedKeys,
})
}
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
// GET /api/v1/admin/users/:id/rpm-status
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, status)
}

View File

@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
RPMLimit: u.RPMLimit,
}
}
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOAuthOnly,
RequirePrivacySet: g.RequirePrivacySet,
RPMLimit: g.RPMLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -108,6 +108,7 @@ type SystemSettings struct {
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration

View File

@@ -26,6 +26,9 @@ type User struct {
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
// RPMLimit 用户级每分钟请求数上限0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
RPMLimit int `json:"rpm_limit"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
@@ -108,6 +111,9 @@ type Group struct {
RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"`
// RPMLimit 分组级每分钟请求数上限0 = 不限制),设置后覆盖用户级 rpm_limit。
RPMLimit int `json:"rpm_limit"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

View File

@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.errorResponse(c, status, code, message)
return
}
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response)
}
func billingErrorDetails(err error) (status int, code, message string) {
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
return http.StatusServiceUnavailable, "billing_service_error", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
// 用户/分组 RPM 超限统一映射为 HTTP 429保留与其它 rate_limit 一致的错误码便于客户端分类。
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
msg := pkgerrors.Message(err)
retrySeconds := 60 - int(time.Now().Unix()%60)
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
}
msg := pkgerrors.Message(err)
if msg == "" {
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
return http.StatusForbidden, "billing_error", msg, 0
}
func (h *GatewayHandler) metadataBridgeEnabled() bool {

View File

@@ -0,0 +1,54 @@
package handler
import (
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
require.Equal(t, http.StatusTooManyRequests, status)
require.Equal(t, "rate_limit_exceeded", code)
require.NotEmpty(t, msg)
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
require.LessOrEqual(t, retryAfter, 60)
}
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
require.Equal(t, http.StatusTooManyRequests, status)
require.Equal(t, "rate_limit_exceeded", code)
require.NotEmpty(t, msg)
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
require.LessOrEqual(t, retryAfter, 60)
}
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
for _, err := range []error{
service.ErrAPIKeyRateLimit5hExceeded,
service.ErrAPIKeyRateLimit1dExceeded,
service.ErrAPIKeyRateLimit7dExceeded,
} {
status, code, _, _ := billingErrorDetails(err)
require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
require.Equal(t, "rate_limit_exceeded", code)
}
}
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
require.Equal(t, http.StatusServiceUnavailable, status)
require.Equal(t, "billing_service_error", code)
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
}
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
require.Equal(t, http.StatusForbidden, status)
require.Equal(t, "billing_error", code)
require.NotEmpty(t, msg)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.chatCompletionsErrorResponse(c, status, code, message)
return
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.responsesErrorResponse(c, status, code, message)
return
}

View File

@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)

View File

@@ -9,6 +9,7 @@ import (
"errors"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
status, _, message := billingErrorDetails(err)
status, _, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
googleError(c, status, message)
return
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}

View File

@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
return
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"strconv"
"strings"
"time"
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
status, code, message, retryAfter := billingErrorDetails(err)
if retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}