Merge branch 'Wei-Shaw:main' into rebuild/auth-identity-foundation
This commit is contained in:
@@ -15,10 +15,12 @@ import (
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
const (
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
EndpointMessages = "/v1/messages"
|
||||
EndpointChatCompletions = "/v1/chat/completions"
|
||||
EndpointResponses = "/v1/responses"
|
||||
EndpointImagesGenerations = "/v1/images/generations"
|
||||
EndpointImagesEdits = "/v1/images/edits"
|
||||
EndpointGeminiModels = "/v1beta/models"
|
||||
)
|
||||
|
||||
// gin.Context keys used by the middleware and helpers below.
|
||||
@@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string {
|
||||
return EndpointChatCompletions
|
||||
case strings.Contains(path, EndpointMessages):
|
||||
return EndpointMessages
|
||||
case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"):
|
||||
return EndpointImagesGenerations
|
||||
case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"):
|
||||
return EndpointImagesEdits
|
||||
case strings.Contains(path, EndpointResponses):
|
||||
return EndpointResponses
|
||||
case strings.Contains(path, EndpointGeminiModels):
|
||||
@@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
||||
|
||||
switch platform {
|
||||
case service.PlatformOpenAI:
|
||||
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
||||
return inbound
|
||||
}
|
||||
// OpenAI forwards everything to the Responses API.
|
||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||
|
||||
@@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
||||
{"/v1/messages", EndpointMessages},
|
||||
{"/v1/chat/completions", EndpointChatCompletions},
|
||||
{"/v1/responses", EndpointResponses},
|
||||
{"/v1/images/generations", EndpointImagesGenerations},
|
||||
{"/v1/images/edits", EndpointImagesEdits},
|
||||
{"/v1beta/models", EndpointGeminiModels},
|
||||
|
||||
// Prefixed paths (antigravity, openai).
|
||||
{"/antigravity/v1/messages", EndpointMessages},
|
||||
{"/openai/v1/responses", EndpointResponses},
|
||||
{"/openai/v1/responses/compact", EndpointResponses},
|
||||
{"/openai/v1/images/generations", EndpointImagesGenerations},
|
||||
{"/openai/v1/images/edits", EndpointImagesEdits},
|
||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||
|
||||
// Gin route patterns with wildcards.
|
||||
@@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
|
||||
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
|
||||
|
||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||
|
||||
300
backend/internal/handler/openai_images.go
Normal file
300
backend/internal/handler/openai_images.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Images handles OpenAI Images API requests.
|
||||
// POST /v1/images/generations
|
||||
// POST /v1/images/edits
|
||||
func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.images",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
|
||||
setOpsRequestContext(c, "", false, nil)
|
||||
} else {
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
}
|
||||
|
||||
parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
reqLog = reqLog.With(
|
||||
zap.String("model", parsed.Model),
|
||||
zap.Bool("stream", parsed.Stream),
|
||||
zap.Bool("multipart", parsed.Multipart),
|
||||
zap.String("capability", string(parsed.RequiredCapability)),
|
||||
)
|
||||
|
||||
if parsed.Multipart {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||
} else {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
|
||||
}
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
|
||||
|
||||
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := ""
|
||||
if parsed.Multipart {
|
||||
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
|
||||
} else {
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
sessionHash,
|
||||
parsed.Model,
|
||||
failedAccountIDs,
|
||||
parsed.RequiredCapability,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.images.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.images.account_schedule_decision",
|
||||
zap.String("layer", scheduleDecision.Layer),
|
||||
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
zap.Int("top_k", scheduleDecision.TopK),
|
||||
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.images.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
if account.Type == service.AccountTypeOAuth {
|
||||
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
if parsed.Multipart {
|
||||
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
|
||||
}
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.images"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", parsed.Model),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai.images.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
|
||||
reqLog.Debug("openai.images.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func isMultipartImagesContentType(contentType string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data")
|
||||
}
|
||||
@@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string {
|
||||
return service.PlatformAntigravity
|
||||
case strings.HasPrefix(p, "/v1beta/"):
|
||||
return service.PlatformGemini
|
||||
case strings.Contains(p, "/responses"):
|
||||
case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
|
||||
return service.PlatformOpenAI
|
||||
default:
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user