P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务
308 lines
10 KiB
Go
308 lines
10 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Responses handles OpenAI Responses API endpoint for Anthropic platform groups.
|
|
// POST /v1/responses
|
|
// This converts Responses API requests to Anthropic format, forwards to Anthropic
|
|
// upstream, and converts responses back to Responses format.
|
|
func (h *GatewayHandler) Responses(c *gin.Context) {
|
|
streamStarted := false
|
|
|
|
requestStart := time.Now()
|
|
|
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
|
if !ok {
|
|
h.responsesErrorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
|
return
|
|
}
|
|
|
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
|
if !ok {
|
|
h.responsesErrorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
|
return
|
|
}
|
|
reqLog := requestLogger(
|
|
c,
|
|
"handler.gateway.responses",
|
|
zap.Int64("user_id", subject.UserID),
|
|
zap.Int64("api_key_id", apiKey.ID),
|
|
zap.Any("group_id", apiKey.GroupID),
|
|
)
|
|
|
|
// Read request body
|
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
|
if err != nil {
|
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
|
h.responsesErrorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
|
return
|
|
}
|
|
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
|
return
|
|
}
|
|
|
|
if len(body) == 0 {
|
|
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
|
return
|
|
}
|
|
|
|
setOpsRequestContext(c, "", false, body)
|
|
|
|
// Validate JSON
|
|
if !gjson.ValidBytes(body) {
|
|
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
return
|
|
}
|
|
|
|
// Extract model and stream using gjson (like OpenAI handler)
|
|
modelResult := gjson.GetBytes(body, "model")
|
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
|
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
|
return
|
|
}
|
|
reqModel := modelResult.String()
|
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
|
|
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
|
|
|
// 解析渠道级模型映射
|
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
|
|
|
// Claude Code only restriction:
|
|
// /v1/responses is never a Claude Code endpoint.
|
|
// When claude_code_only is enabled, this endpoint is rejected.
|
|
// The existing service-layer checkClaudeCodeRestriction handles degradation
|
|
// to fallback groups when the Forward path calls SelectAccountForModelWithExclusions.
|
|
// Here we just reject at handler level since /v1/responses clients can't be Claude Code.
|
|
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
|
h.responsesErrorResponse(c, http.StatusForbidden, "permission_error",
|
|
"This group is restricted to Claude Code clients (/v1/messages only)")
|
|
return
|
|
}
|
|
|
|
// Error passthrough binding
|
|
if h.errorPassthroughService != nil {
|
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
|
}
|
|
|
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
|
|
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
|
|
|
// 1. Acquire user concurrency slot
|
|
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
|
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
|
waitCounted := false
|
|
if err != nil {
|
|
reqLog.Warn("gateway.responses.user_wait_counter_increment_failed", zap.Error(err))
|
|
} else if !canWait {
|
|
h.responsesErrorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
|
return
|
|
}
|
|
if err == nil && canWait {
|
|
waitCounted = true
|
|
}
|
|
defer func() {
|
|
if waitCounted {
|
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
}
|
|
}()
|
|
|
|
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
|
if err != nil {
|
|
reqLog.Warn("gateway.responses.user_slot_acquire_failed", zap.Error(err))
|
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
|
return
|
|
}
|
|
if waitCounted {
|
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
|
waitCounted = false
|
|
}
|
|
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
|
if userReleaseFunc != nil {
|
|
defer userReleaseFunc()
|
|
}
|
|
|
|
// 2. Re-check billing
|
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
|
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
|
|
status, code, message, retryAfter := billingErrorDetails(err)
|
|
if retryAfter > 0 {
|
|
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
|
}
|
|
h.responsesErrorResponse(c, status, code, message)
|
|
return
|
|
}
|
|
|
|
// Parse request for session hash
|
|
parsedReq, _ := service.ParseGatewayRequest(body, "responses")
|
|
if parsedReq == nil {
|
|
parsedReq = &service.ParsedRequest{Model: reqModel, Stream: reqStream, Body: body}
|
|
}
|
|
parsedReq.SessionContext = &service.SessionContext{
|
|
ClientIP: ip.GetClientIP(c),
|
|
UserAgent: c.GetHeader("User-Agent"),
|
|
APIKeyID: apiKey.ID,
|
|
}
|
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
|
|
|
// 3. Account selection + failover loop
|
|
fs := NewFailoverState(h.maxAccountSwitches, false)
|
|
|
|
for {
|
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
|
if err != nil {
|
|
if len(fs.FailedAccountIDs) == 0 {
|
|
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
|
return
|
|
}
|
|
action := fs.HandleSelectionExhausted(c.Request.Context())
|
|
switch action {
|
|
case FailoverContinue:
|
|
continue
|
|
case FailoverCanceled:
|
|
return
|
|
default:
|
|
if fs.LastFailoverErr != nil {
|
|
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
|
} else {
|
|
h.responsesErrorResponse(c, http.StatusBadGateway, "server_error", "All available accounts exhausted")
|
|
}
|
|
return
|
|
}
|
|
}
|
|
account := selection.Account
|
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
|
|
|
// 4. Acquire account concurrency slot
|
|
accountReleaseFunc := selection.ReleaseFunc
|
|
if !selection.Acquired {
|
|
if selection.WaitPlan == nil {
|
|
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
|
return
|
|
}
|
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
|
c,
|
|
account.ID,
|
|
selection.WaitPlan.MaxConcurrency,
|
|
selection.WaitPlan.Timeout,
|
|
reqStream,
|
|
&streamStarted,
|
|
)
|
|
if err != nil {
|
|
reqLog.Warn("gateway.responses.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
|
return
|
|
}
|
|
}
|
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
|
|
|
// 5. Forward request
|
|
writerSizeBeforeForward := c.Writer.Size()
|
|
forwardBody := body
|
|
if channelMapping.Mapped {
|
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
|
}
|
|
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
|
|
|
|
if accountReleaseFunc != nil {
|
|
accountReleaseFunc()
|
|
}
|
|
|
|
if err != nil {
|
|
var failoverErr *service.UpstreamFailoverError
|
|
if errors.As(err, &failoverErr) {
|
|
// Can't failover if streaming content already sent
|
|
if c.Writer.Size() != writerSizeBeforeForward {
|
|
h.handleResponsesFailoverExhausted(c, failoverErr, true)
|
|
return
|
|
}
|
|
action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
|
switch action {
|
|
case FailoverContinue:
|
|
continue
|
|
case FailoverExhausted:
|
|
h.handleResponsesFailoverExhausted(c, fs.LastFailoverErr, streamStarted)
|
|
return
|
|
case FailoverCanceled:
|
|
return
|
|
}
|
|
}
|
|
h.ensureForwardErrorResponse(c, streamStarted)
|
|
reqLog.Error("gateway.responses.forward_failed",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
// 6. Record usage
|
|
userAgent := c.GetHeader("User-Agent")
|
|
clientIP := ip.GetClientIP(c)
|
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
|
inboundEndpoint := GetInboundEndpoint(c)
|
|
upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
|
|
|
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
|
Result: result,
|
|
APIKey: apiKey,
|
|
User: apiKey.User,
|
|
Account: account,
|
|
Subscription: subscription,
|
|
InboundEndpoint: inboundEndpoint,
|
|
UpstreamEndpoint: upstreamEndpoint,
|
|
UserAgent: userAgent,
|
|
IPAddress: clientIP,
|
|
RequestPayloadHash: requestPayloadHash,
|
|
APIKeyService: h.apiKeyService,
|
|
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
|
|
}); err != nil {
|
|
reqLog.Error("gateway.responses.record_usage_failed",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
// responsesErrorResponse writes an error in OpenAI Responses API format.
|
|
func (h *GatewayHandler) responsesErrorResponse(c *gin.Context, status int, code, message string) {
|
|
c.JSON(status, gin.H{
|
|
"error": gin.H{
|
|
"code": code,
|
|
"message": message,
|
|
},
|
|
})
|
|
}
|
|
|
|
// handleResponsesFailoverExhausted writes a failover-exhausted error in Responses format.
|
|
func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastErr *service.UpstreamFailoverError, streamStarted bool) {
|
|
if streamStarted {
|
|
return // Can't write error after stream started
|
|
}
|
|
statusCode := http.StatusBadGateway
|
|
if lastErr != nil && lastErr.StatusCode > 0 {
|
|
statusCode = lastErr.StatusCode
|
|
}
|
|
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
|
}
|