feat(rpm): RPM 限流模块优化
P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务
This commit is contained in:
@@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
return map[string]any{"user_id": userID}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
|
||||
user, err := s.GetUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.UserRPMStatus{
|
||||
UserRPMUsed: 0,
|
||||
UserRPMLimit: user.RPMLimit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
|
||||
s.boundAuthIdentityFor = userID
|
||||
copied := input
|
||||
@@ -276,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
|
||||
s.lastListAccounts.platform = platform
|
||||
s.lastListAccounts.accountType = accountType
|
||||
|
||||
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
// 分组 RPM 上限(0 = 不限制)
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
|
||||
RequirePrivacySet *bool `json:"require_privacy_set"`
|
||||
DefaultMappedModel *string `json:"default_mapped_model"`
|
||||
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
|
||||
// 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
|
||||
RPMLimit *int `json:"rpm_limit"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
RequirePrivacySet: req.RequirePrivacySet,
|
||||
DefaultMappedModel: req.DefaultMappedModel,
|
||||
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
|
||||
RPMLimit: req.RPMLimit,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
|
||||
}
|
||||
|
||||
// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
|
||||
type BatchSetGroupRPMOverridesRequest struct {
|
||||
Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
|
||||
}
|
||||
|
||||
// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
|
||||
// PUT /api/v1/admin/groups/:id/rpm-overrides
|
||||
func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchSetGroupRPMOverridesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
|
||||
}
|
||||
|
||||
// ClearGroupRPMOverrides handles clearing all rpm_override for a group
|
||||
// DELETE /api/v1/admin/groups/:id/rpm-overrides
|
||||
func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
|
||||
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid group ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
|
||||
}
|
||||
|
||||
// UpdateSortOrderRequest represents the request to update group sort orders
|
||||
type UpdateSortOrderRequest struct {
|
||||
Updates []struct {
|
||||
|
||||
@@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
@@ -332,6 +333,7 @@ type UpdateSettingsRequest struct {
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
@@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
@@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
|
||||
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
}
|
||||
|
||||
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
|
||||
Notes *string `json:"notes"`
|
||||
Balance *float64 `json:"balance"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
RPMLimit *int `json:"rpm_limit"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
@@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
RPMLimit: req.RPMLimit,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
RPMLimit: req.RPMLimit,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
@@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
|
||||
"migrated_keys": result.MigratedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
|
||||
// GET /api/v1/admin/users/:id/rpm-status
|
||||
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid user ID")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, status)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
|
||||
TotalRecharged: u.TotalRecharged,
|
||||
RPMLimit: u.RPMLimit,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: g.RequireOAuthOnly,
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
RPMLimit: g.RPMLimit,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -108,6 +108,7 @@ type SystemSettings struct {
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
|
||||
@@ -26,6 +26,9 @@ type User struct {
|
||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
|
||||
TotalRecharged float64 `json:"total_recharged"`
|
||||
|
||||
// RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
@@ -108,6 +111,9 @@ type Group struct {
|
||||
RequireOAuthOnly bool `json:"require_oauth_only"`
|
||||
RequirePrivacySet bool `json:"require_privacy_set"`
|
||||
|
||||
// RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
|
||||
RPMLimit int `json:"rpm_limit"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -1441,7 +1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.errorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
@@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string) {
|
||||
func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
|
||||
if errors.Is(err, service.ErrBillingServiceUnavailable) {
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = "Billing service temporarily unavailable. Please retry later."
|
||||
}
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg
|
||||
return http.StatusServiceUnavailable, "billing_service_error", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
|
||||
}
|
||||
// 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
|
||||
// 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
|
||||
if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
|
||||
msg := pkgerrors.Message(err)
|
||||
retrySeconds := 60 - int(time.Now().Unix()%60)
|
||||
return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
@@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
).Warn("gateway.billing_error_missing_message")
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
return http.StatusForbidden, "billing_error", msg, 0
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) metadataBridgeEnabled() bool {
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
|
||||
status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
|
||||
require.Equal(t, http.StatusTooManyRequests, status)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
require.NotEmpty(t, msg)
|
||||
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||
require.LessOrEqual(t, retryAfter, 60)
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
|
||||
status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
|
||||
require.Equal(t, http.StatusTooManyRequests, status)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
require.NotEmpty(t, msg)
|
||||
require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
|
||||
require.LessOrEqual(t, retryAfter, 60)
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
|
||||
// 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
|
||||
for _, err := range []error{
|
||||
service.ErrAPIKeyRateLimit5hExceeded,
|
||||
service.ErrAPIKeyRateLimit1dExceeded,
|
||||
service.ErrAPIKeyRateLimit7dExceeded,
|
||||
} {
|
||||
status, code, _, _ := billingErrorDetails(err)
|
||||
require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
|
||||
require.Equal(t, "rate_limit_exceeded", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
|
||||
status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
|
||||
require.Equal(t, http.StatusServiceUnavailable, status)
|
||||
require.Equal(t, "billing_service_error", code)
|
||||
require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
|
||||
}
|
||||
|
||||
func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
|
||||
status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
|
||||
require.Equal(t, http.StatusForbidden, status)
|
||||
require.Equal(t, "billing_error", code)
|
||||
require.NotEmpty(t, msg)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
// 2. Re-check billing
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.responsesErrorResponse(c, status, code, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message := billingErrorDetails(err)
|
||||
status, _, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
googleError(c, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 2. Re-check billing eligibility after wait
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
status, code, message, retryAfter := billingErrorDetails(err)
|
||||
if retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user