Merge branch 'Wei-Shaw:main' into rebuild/auth-identity-foundation
This commit is contained in:
@@ -15,10 +15,12 @@ import (
|
|||||||
// ──────────────────────────────────────────────────────────
|
// ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EndpointMessages = "/v1/messages"
|
EndpointMessages = "/v1/messages"
|
||||||
EndpointChatCompletions = "/v1/chat/completions"
|
EndpointChatCompletions = "/v1/chat/completions"
|
||||||
EndpointResponses = "/v1/responses"
|
EndpointResponses = "/v1/responses"
|
||||||
EndpointGeminiModels = "/v1beta/models"
|
EndpointImagesGenerations = "/v1/images/generations"
|
||||||
|
EndpointImagesEdits = "/v1/images/edits"
|
||||||
|
EndpointGeminiModels = "/v1beta/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
// gin.Context keys used by the middleware and helpers below.
|
// gin.Context keys used by the middleware and helpers below.
|
||||||
@@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string {
|
|||||||
return EndpointChatCompletions
|
return EndpointChatCompletions
|
||||||
case strings.Contains(path, EndpointMessages):
|
case strings.Contains(path, EndpointMessages):
|
||||||
return EndpointMessages
|
return EndpointMessages
|
||||||
|
case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"):
|
||||||
|
return EndpointImagesGenerations
|
||||||
|
case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"):
|
||||||
|
return EndpointImagesEdits
|
||||||
case strings.Contains(path, EndpointResponses):
|
case strings.Contains(path, EndpointResponses):
|
||||||
return EndpointResponses
|
return EndpointResponses
|
||||||
case strings.Contains(path, EndpointGeminiModels):
|
case strings.Contains(path, EndpointGeminiModels):
|
||||||
@@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
|
|||||||
|
|
||||||
switch platform {
|
switch platform {
|
||||||
case service.PlatformOpenAI:
|
case service.PlatformOpenAI:
|
||||||
|
if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
|
||||||
|
return inbound
|
||||||
|
}
|
||||||
// OpenAI forwards everything to the Responses API.
|
// OpenAI forwards everything to the Responses API.
|
||||||
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
// Preserve subresource suffix (e.g. /v1/responses/compact).
|
||||||
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
|
||||||
|
|||||||
@@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
|
|||||||
{"/v1/messages", EndpointMessages},
|
{"/v1/messages", EndpointMessages},
|
||||||
{"/v1/chat/completions", EndpointChatCompletions},
|
{"/v1/chat/completions", EndpointChatCompletions},
|
||||||
{"/v1/responses", EndpointResponses},
|
{"/v1/responses", EndpointResponses},
|
||||||
|
{"/v1/images/generations", EndpointImagesGenerations},
|
||||||
|
{"/v1/images/edits", EndpointImagesEdits},
|
||||||
{"/v1beta/models", EndpointGeminiModels},
|
{"/v1beta/models", EndpointGeminiModels},
|
||||||
|
|
||||||
// Prefixed paths (antigravity, openai).
|
// Prefixed paths (antigravity, openai).
|
||||||
{"/antigravity/v1/messages", EndpointMessages},
|
{"/antigravity/v1/messages", EndpointMessages},
|
||||||
{"/openai/v1/responses", EndpointResponses},
|
{"/openai/v1/responses", EndpointResponses},
|
||||||
{"/openai/v1/responses/compact", EndpointResponses},
|
{"/openai/v1/responses/compact", EndpointResponses},
|
||||||
|
{"/openai/v1/images/generations", EndpointImagesGenerations},
|
||||||
|
{"/openai/v1/images/edits", EndpointImagesEdits},
|
||||||
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
|
||||||
|
|
||||||
// Gin route patterns with wildcards.
|
// Gin route patterns with wildcards.
|
||||||
@@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
|
|||||||
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
|
||||||
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
|
||||||
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
|
||||||
|
{"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
|
||||||
|
{"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
|
||||||
|
|
||||||
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
|
||||||
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
|
||||||
|
|||||||
300
backend/internal/handler/openai_images.go
Normal file
300
backend/internal/handler/openai_images.go
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Images handles OpenAI Images API requests.
|
||||||
|
// POST /v1/images/generations
|
||||||
|
// POST /v1/images/edits
|
||||||
|
func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||||
|
streamStarted := false
|
||||||
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||||
|
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog := requestLogger(
|
||||||
|
c,
|
||||||
|
"handler.openai_gateway.images",
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
)
|
||||||
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||||
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
|
||||||
|
setOpsRequestContext(c, "", false, nil)
|
||||||
|
} else {
|
||||||
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
|
||||||
|
if err != nil {
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog = reqLog.With(
|
||||||
|
zap.String("model", parsed.Model),
|
||||||
|
zap.Bool("stream", parsed.Stream),
|
||||||
|
zap.Bool("multipart", parsed.Multipart),
|
||||||
|
zap.String("capability", string(parsed.RequiredCapability)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.Multipart {
|
||||||
|
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||||
|
} else {
|
||||||
|
setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
|
||||||
|
}
|
||||||
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
|
||||||
|
|
||||||
|
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model)
|
||||||
|
|
||||||
|
if h.errorPassthroughService != nil {
|
||||||
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
|
}
|
||||||
|
|
||||||
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
|
routingStart := time.Now()
|
||||||
|
|
||||||
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userReleaseFunc != nil {
|
||||||
|
defer userReleaseFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||||
|
reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
|
||||||
|
status, code, message := billingErrorDetails(err)
|
||||||
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionHash := ""
|
||||||
|
if parsed.Multipart {
|
||||||
|
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
|
||||||
|
} else {
|
||||||
|
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
|
switchCount := 0
|
||||||
|
failedAccountIDs := make(map[int64]struct{})
|
||||||
|
sameAccountRetryCount := make(map[int64]int)
|
||||||
|
var lastFailoverErr *service.UpstreamFailoverError
|
||||||
|
|
||||||
|
for {
|
||||||
|
reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||||
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages(
|
||||||
|
c.Request.Context(),
|
||||||
|
apiKey.GroupID,
|
||||||
|
sessionHash,
|
||||||
|
parsed.Model,
|
||||||
|
failedAccountIDs,
|
||||||
|
parsed.RequiredCapability,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
reqLog.Warn("openai.images.account_select_failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||||
|
)
|
||||||
|
if len(failedAccountIDs) == 0 {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lastFailoverErr != nil {
|
||||||
|
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||||
|
} else {
|
||||||
|
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqLog.Debug("openai.images.account_schedule_decision",
|
||||||
|
zap.String("layer", scheduleDecision.Layer),
|
||||||
|
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||||
|
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||||
|
zap.Int("top_k", scheduleDecision.TopK),
|
||||||
|
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||||
|
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||||
|
)
|
||||||
|
|
||||||
|
account := selection.Account
|
||||||
|
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||||
|
reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||||
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||||
|
|
||||||
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog)
|
||||||
|
if !acquired {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
|
forwardStart := time.Now()
|
||||||
|
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||||
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
|
if accountReleaseFunc != nil {
|
||||||
|
accountReleaseFunc()
|
||||||
|
}
|
||||||
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
responseLatencyMs := forwardDurationMs
|
||||||
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||||
|
}
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||||
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var failoverErr *service.UpstreamFailoverError
|
||||||
|
if errors.As(err, &failoverErr) {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
if failoverErr.RetryableOnSameAccount {
|
||||||
|
retryLimit := account.GetPoolModeRetryCount()
|
||||||
|
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||||
|
sameAccountRetryCount[account.ID]++
|
||||||
|
reqLog.Warn("openai.images.pool_mode_same_account_retry",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("retry_limit", retryLimit),
|
||||||
|
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||||
|
)
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
case <-time.After(sameAccountRetryDelay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||||
|
failedAccountIDs[account.ID] = struct{}{}
|
||||||
|
lastFailoverErr = failoverErr
|
||||||
|
if switchCount >= maxAccountSwitches {
|
||||||
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switchCount++
|
||||||
|
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
zap.Int("max_switches", maxAccountSwitches),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||||
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||||
|
zap.Error(err),
|
||||||
|
}
|
||||||
|
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||||
|
reqLog.Warn("openai.images.forward_failed", fields...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqLog.Error("openai.images.forward_failed", fields...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != nil {
|
||||||
|
if account.Type == service.AccountTypeOAuth {
|
||||||
|
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
|
||||||
|
}
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||||
|
} else {
|
||||||
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := c.GetHeader("User-Agent")
|
||||||
|
clientIP := ip.GetClientIP(c)
|
||||||
|
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||||
|
if parsed.Multipart {
|
||||||
|
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
|
||||||
|
}
|
||||||
|
|
||||||
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
|
Result: result,
|
||||||
|
APIKey: apiKey,
|
||||||
|
User: apiKey.User,
|
||||||
|
Account: account,
|
||||||
|
Subscription: subscription,
|
||||||
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
|
UserAgent: userAgent,
|
||||||
|
IPAddress: clientIP,
|
||||||
|
RequestPayloadHash: requestPayloadHash,
|
||||||
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
|
||||||
|
}); err != nil {
|
||||||
|
logger.L().With(
|
||||||
|
zap.String("component", "handler.openai_gateway.images"),
|
||||||
|
zap.Int64("user_id", subject.UserID),
|
||||||
|
zap.Int64("api_key_id", apiKey.ID),
|
||||||
|
zap.Any("group_id", apiKey.GroupID),
|
||||||
|
zap.String("model", parsed.Model),
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
).Error("openai.images.record_usage_failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
reqLog.Debug("openai.images.request_completed",
|
||||||
|
zap.Int64("account_id", account.ID),
|
||||||
|
zap.Int("switch_count", switchCount),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMultipartImagesContentType(contentType string) bool {
|
||||||
|
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data")
|
||||||
|
}
|
||||||
@@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string {
|
|||||||
return service.PlatformAntigravity
|
return service.PlatformAntigravity
|
||||||
case strings.HasPrefix(p, "/v1beta/"):
|
case strings.HasPrefix(p, "/v1beta/"):
|
||||||
return service.PlatformGemini
|
return service.PlatformGemini
|
||||||
case strings.Contains(p, "/responses"):
|
case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
|
||||||
return service.PlatformOpenAI
|
return service.PlatformOpenAI
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ var DefaultModels = []Model{
|
|||||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||||
|
{ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultModelIDs returns the default model ID list
|
// DefaultModelIDs returns the default model ID list
|
||||||
|
|||||||
@@ -96,7 +96,8 @@ func isAPIRoutePath(c *gin.Context) bool {
|
|||||||
return strings.HasPrefix(path, "/v1/") ||
|
return strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
strings.HasPrefix(path, "/antigravity/") ||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/responses")
|
strings.HasPrefix(path, "/responses") ||
|
||||||
|
strings.HasPrefix(path, "/images")
|
||||||
}
|
}
|
||||||
|
|
||||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
|
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
|
||||||
|
|||||||
@@ -88,6 +88,30 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
h.Gateway.ChatCompletions(c)
|
h.Gateway.ChatCompletions(c)
|
||||||
})
|
})
|
||||||
|
gateway.POST("/images/generations", func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Images API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Images(c)
|
||||||
|
})
|
||||||
|
gateway.POST("/images/edits", func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Images API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Images(c)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||||
@@ -124,6 +148,30 @@ func RegisterGatewayRoutes(
|
|||||||
}
|
}
|
||||||
h.Gateway.ChatCompletions(c)
|
h.Gateway.ChatCompletions(c)
|
||||||
})
|
})
|
||||||
|
r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Images API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Images(c)
|
||||||
|
})
|
||||||
|
r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
|
||||||
|
if getGroupPlatform(c) != service.PlatformOpenAI {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"type": "not_found_error",
|
||||||
|
"message": "Images API is not supported for this platform",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.OpenAIGateway.Images(c)
|
||||||
|
})
|
||||||
|
|
||||||
// Antigravity 模型列表
|
// Antigravity 模型列表
|
||||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -24,6 +25,11 @@ func newGatewayRoutesTestRouter() *gin.Engine {
|
|||||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||||
},
|
},
|
||||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||||
|
groupID := int64(1)
|
||||||
|
c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{
|
||||||
|
GroupID: &groupID,
|
||||||
|
Group: &service.Group{Platform: service.PlatformOpenAI},
|
||||||
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
}),
|
}),
|
||||||
nil,
|
nil,
|
||||||
@@ -48,3 +54,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
|
|||||||
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) {
|
||||||
|
router := newGatewayRoutesTestRouter()
|
||||||
|
|
||||||
|
for _, path := range []string{
|
||||||
|
"/v1/images/generations",
|
||||||
|
"/v1/images/edits",
|
||||||
|
"/images/generations",
|
||||||
|
"/images/edits",
|
||||||
|
} {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -911,6 +911,34 @@ func (a *Account) GetChatGPTAccountID() string {
|
|||||||
return a.GetCredential("chatgpt_account_id")
|
return a.GetCredential("chatgpt_account_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetOpenAIDeviceID() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(a.GetExtraString("openai_device_id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetOpenAISessionID() string {
|
||||||
|
if !a.IsOpenAIOAuth() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(a.GetExtraString("openai_session_id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
|
||||||
|
if !a.IsOpenAI() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch capability {
|
||||||
|
case OpenAIImagesCapabilityBasic:
|
||||||
|
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
|
||||||
|
case OpenAIImagesCapabilityNative:
|
||||||
|
return a.Type == AccountTypeAPIKey
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) GetChatGPTUserID() string {
|
func (a *Account) GetChatGPTUserID() string {
|
||||||
if !a.IsOpenAIOAuth() {
|
if !a.IsOpenAIOAuth() {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -61,6 +61,25 @@ type PricingInput struct {
|
|||||||
// 1. 获取基础定价(LiteLLM → Fallback)
|
// 1. 获取基础定价(LiteLLM → Fallback)
|
||||||
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
||||||
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
||||||
|
var chPricing *ChannelModelPricing
|
||||||
|
if input.GroupID != nil && r.channelService != nil {
|
||||||
|
chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model)
|
||||||
|
if chPricing != nil {
|
||||||
|
mode := chPricing.BillingMode
|
||||||
|
if mode == "" {
|
||||||
|
mode = BillingModeToken
|
||||||
|
}
|
||||||
|
if mode == BillingModePerRequest || mode == BillingModeImage {
|
||||||
|
resolved := &ResolvedPricing{
|
||||||
|
Mode: mode,
|
||||||
|
Source: PricingSourceChannel,
|
||||||
|
}
|
||||||
|
r.applyRequestTierOverrides(chPricing, resolved)
|
||||||
|
return resolved
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 获取基础定价
|
// 1. 获取基础定价
|
||||||
basePricing, source := r.resolveBasePricing(input.Model)
|
basePricing, source := r.resolveBasePricing(input.Model)
|
||||||
|
|
||||||
@@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. 如果有 GroupID,尝试渠道覆盖
|
// 2. 如果有 GroupID,尝试渠道覆盖
|
||||||
if input.GroupID != nil {
|
if chPricing != nil {
|
||||||
|
resolved.Source = PricingSourceChannel
|
||||||
|
r.applyTokenOverrides(chPricing, resolved)
|
||||||
|
} else if input.GroupID != nil {
|
||||||
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,13 +38,14 @@ var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSch
|
|||||||
var openAIAdvancedSchedulerSettingSF singleflight.Group
|
var openAIAdvancedSchedulerSettingSF singleflight.Group
|
||||||
|
|
||||||
type OpenAIAccountScheduleRequest struct {
|
type OpenAIAccountScheduleRequest struct {
|
||||||
GroupID *int64
|
GroupID *int64
|
||||||
SessionHash string
|
SessionHash string
|
||||||
StickyAccountID int64
|
StickyAccountID int64
|
||||||
PreviousResponseID string
|
PreviousResponseID string
|
||||||
RequestedModel string
|
RequestedModel string
|
||||||
RequiredTransport OpenAIUpstreamTransport
|
RequiredTransport OpenAIUpstreamTransport
|
||||||
ExcludedIDs map[int64]struct{}
|
RequiredImageCapability OpenAIImagesCapability
|
||||||
|
ExcludedIDs map[int64]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIAccountScheduleDecision struct {
|
type OpenAIAccountScheduleDecision struct {
|
||||||
@@ -340,7 +341,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
|||||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
if !s.isAccountRequestCompatible(account, req) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||||
@@ -616,7 +617,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
if !s.isAccountRequestCompatible(account, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||||
@@ -722,11 +723,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
for i := 0; i < len(selectionOrder); i++ {
|
for i := 0; i < len(selectionOrder); i++ {
|
||||||
candidate := selectionOrder[i]
|
candidate := selectionOrder[i]
|
||||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
|
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||||
@@ -749,7 +750,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
|||||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||||
for _, candidate := range selectionOrder {
|
for _, candidate := range selectionOrder {
|
||||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
@@ -776,6 +777,16 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
|
|||||||
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
|
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||||
|
if account == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||||
if s == nil || s.stats == nil {
|
if s == nil || s.stats == nil {
|
||||||
return
|
return
|
||||||
@@ -894,14 +905,59 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
|||||||
requestedModel string,
|
requestedModel string,
|
||||||
excludedIDs map[int64]struct{},
|
excludedIDs map[int64]struct{},
|
||||||
requiredTransport OpenAIUpstreamTransport,
|
requiredTransport OpenAIUpstreamTransport,
|
||||||
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
|
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
sessionHash string,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
requiredCapability OpenAIImagesCapability,
|
||||||
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
|
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||||
|
ctx context.Context,
|
||||||
|
groupID *int64,
|
||||||
|
previousResponseID string,
|
||||||
|
sessionHash string,
|
||||||
|
requestedModel string,
|
||||||
|
excludedIDs map[int64]struct{},
|
||||||
|
requiredTransport OpenAIUpstreamTransport,
|
||||||
|
requiredImageCapability OpenAIImagesCapability,
|
||||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||||
decision := OpenAIAccountScheduleDecision{}
|
decision := OpenAIAccountScheduleDecision{}
|
||||||
scheduler := s.getOpenAIAccountScheduler(ctx)
|
scheduler := s.getOpenAIAccountScheduler(ctx)
|
||||||
if scheduler == nil {
|
if scheduler == nil {
|
||||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||||
return selection, decision, err
|
for {
|
||||||
|
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, decision, err
|
||||||
|
}
|
||||||
|
if selection == nil || selection.Account == nil {
|
||||||
|
return selection, decision, nil
|
||||||
|
}
|
||||||
|
if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
|
||||||
|
return selection, decision, nil
|
||||||
|
}
|
||||||
|
if selection.ReleaseFunc != nil {
|
||||||
|
selection.ReleaseFunc()
|
||||||
|
}
|
||||||
|
if effectiveExcludedIDs == nil {
|
||||||
|
effectiveExcludedIDs = make(map[int64]struct{})
|
||||||
|
}
|
||||||
|
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
|
||||||
|
return nil, decision, ErrNoAvailableAccounts
|
||||||
|
}
|
||||||
|
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||||
@@ -937,13 +993,14 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
|
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
SessionHash: sessionHash,
|
SessionHash: sessionHash,
|
||||||
StickyAccountID: stickyAccountID,
|
StickyAccountID: stickyAccountID,
|
||||||
PreviousResponseID: previousResponseID,
|
PreviousResponseID: previousResponseID,
|
||||||
RequestedModel: requestedModel,
|
RequestedModel: requestedModel,
|
||||||
RequiredTransport: requiredTransport,
|
RequiredTransport: requiredTransport,
|
||||||
ExcludedIDs: excludedIDs,
|
RequiredImageCapability: requiredImageCapability,
|
||||||
|
ExcludedIDs: excludedIDs,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1070,3 +1070,31 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t
|
|||||||
require.Equal(t, 0, userRepo.deductCalls)
|
require.Equal(t, 0, userRepo.deductCalls)
|
||||||
require.Equal(t, 0, subRepo.incrementCalls)
|
require.Equal(t, 0, subRepo.incrementCalls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "resp_image_only_usage",
|
||||||
|
Model: "gpt-image-2",
|
||||||
|
ImageCount: 2,
|
||||||
|
ImageSize: "1K",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 1007},
|
||||||
|
User: &User{ID: 2007},
|
||||||
|
Account: &Account{ID: 3007},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, 2, usageRepo.lastLog.ImageCount)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.ImageSize)
|
||||||
|
require.Equal(t, "1K", *usageRepo.lastLog.ImageSize)
|
||||||
|
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||||
|
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||||
|
}
|
||||||
|
|||||||
@@ -233,6 +233,8 @@ type OpenAIForwardResult struct {
|
|||||||
ResponseHeaders http.Header
|
ResponseHeaders http.Header
|
||||||
Duration time.Duration
|
Duration time.Duration
|
||||||
FirstTokenMs *int
|
FirstTokenMs *int
|
||||||
|
ImageCount int
|
||||||
|
ImageSize string
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIWSRetryMetricsSnapshot struct {
|
type OpenAIWSRetryMetricsSnapshot struct {
|
||||||
@@ -3889,6 +3891,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
|||||||
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
|
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
|
||||||
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
|
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
|
||||||
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
||||||
|
usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int())
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||||
@@ -3900,11 +3903,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
|||||||
"usage.input_tokens",
|
"usage.input_tokens",
|
||||||
"usage.output_tokens",
|
"usage.output_tokens",
|
||||||
"usage.input_tokens_details.cached_tokens",
|
"usage.input_tokens_details.cached_tokens",
|
||||||
|
"usage.output_tokens_details.image_tokens",
|
||||||
)
|
)
|
||||||
return OpenAIUsage{
|
return OpenAIUsage{
|
||||||
InputTokens: int(values[0].Int()),
|
InputTokens: int(values[0].Int()),
|
||||||
OutputTokens: int(values[1].Int()),
|
OutputTokens: int(values[1].Int()),
|
||||||
CacheReadInputTokens: int(values[2].Int()),
|
CacheReadInputTokens: int(values[2].Int()),
|
||||||
|
ImageOutputTokens: int(values[3].Int()),
|
||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4397,7 +4402,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
|
|
||||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
|
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
||||||
|
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4451,21 +4457,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
if result.ServiceTier != nil {
|
if result.ServiceTier != nil {
|
||||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||||
}
|
}
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
|
||||||
gid := apiKey.Group.ID
|
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
|
||||||
Ctx: ctx,
|
|
||||||
Model: billingModel,
|
|
||||||
GroupID: &gid,
|
|
||||||
Tokens: tokens,
|
|
||||||
RequestCount: 1,
|
|
||||||
RateMultiplier: multiplier,
|
|
||||||
ServiceTier: serviceTier,
|
|
||||||
Resolver: s.resolver,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cost = &CostBreakdown{ActualCost: 0}
|
cost = &CostBreakdown{ActualCost: 0}
|
||||||
}
|
}
|
||||||
@@ -4505,6 +4497,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||||
|
ImageCount: result.ImageCount,
|
||||||
|
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||||
}
|
}
|
||||||
if cost != nil {
|
if cost != nil {
|
||||||
usageLog.InputCost = cost.InputCost
|
usageLog.InputCost = cost.InputCost
|
||||||
@@ -4530,6 +4524,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
if cost != nil && cost.BillingMode != "" {
|
if cost != nil && cost.BillingMode != "" {
|
||||||
billingMode := cost.BillingMode
|
billingMode := cost.BillingMode
|
||||||
usageLog.BillingMode = &billingMode
|
usageLog.BillingMode = &billingMode
|
||||||
|
} else if result.ImageCount > 0 {
|
||||||
|
billingMode := string(BillingModeImage)
|
||||||
|
usageLog.BillingMode = &billingMode
|
||||||
} else {
|
} else {
|
||||||
billingMode := string(BillingModeToken)
|
billingMode := string(BillingModeToken)
|
||||||
usageLog.BillingMode = &billingMode
|
usageLog.BillingMode = &billingMode
|
||||||
@@ -4589,6 +4586,125 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
||||||
|
ctx context.Context,
|
||||||
|
result *OpenAIForwardResult,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
tokens UsageTokens,
|
||||||
|
serviceTier string,
|
||||||
|
) (*CostBreakdown, error) {
|
||||||
|
if result != nil && result.ImageCount > 0 {
|
||||||
|
if hasOpenAIImageUsageTokens(result) {
|
||||||
|
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
|
||||||
|
if err == nil {
|
||||||
|
return cost, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||||
|
}
|
||||||
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
|
gid := apiKey.Group.ID
|
||||||
|
return s.billingService.CalculateCostUnified(CostInput{
|
||||||
|
Ctx: ctx,
|
||||||
|
Model: billingModel,
|
||||||
|
GroupID: &gid,
|
||||||
|
Tokens: tokens,
|
||||||
|
RequestCount: 1,
|
||||||
|
RateMultiplier: multiplier,
|
||||||
|
ServiceTier: serviceTier,
|
||||||
|
Resolver: s.resolver,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
|
||||||
|
ctx context.Context,
|
||||||
|
apiKey *APIKey,
|
||||||
|
billingModel string,
|
||||||
|
multiplier float64,
|
||||||
|
tokens UsageTokens,
|
||||||
|
serviceTier string,
|
||||||
|
sizeTier string,
|
||||||
|
) (*CostBreakdown, error) {
|
||||||
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
|
gid := apiKey.Group.ID
|
||||||
|
return s.billingService.CalculateCostUnified(CostInput{
|
||||||
|
Ctx: ctx,
|
||||||
|
Model: billingModel,
|
||||||
|
GroupID: &gid,
|
||||||
|
Tokens: tokens,
|
||||||
|
RequestCount: 1,
|
||||||
|
SizeTier: sizeTier,
|
||||||
|
RateMultiplier: multiplier,
|
||||||
|
ServiceTier: serviceTier,
|
||||||
|
Resolver: s.resolver,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
||||||
|
ctx context.Context,
|
||||||
|
billingModel string,
|
||||||
|
apiKey *APIKey,
|
||||||
|
result *OpenAIForwardResult,
|
||||||
|
multiplier float64,
|
||||||
|
) *CostBreakdown {
|
||||||
|
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
||||||
|
gid := apiKey.Group.ID
|
||||||
|
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||||
|
Ctx: ctx,
|
||||||
|
Model: billingModel,
|
||||||
|
GroupID: &gid,
|
||||||
|
RequestCount: 1,
|
||||||
|
SizeTier: result.ImageSize,
|
||||||
|
RateMultiplier: multiplier,
|
||||||
|
Resolver: s.resolver,
|
||||||
|
Resolved: resolved,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var groupConfig *ImagePriceConfig
|
||||||
|
if apiKey != nil && apiKey.Group != nil {
|
||||||
|
groupConfig = &ImagePriceConfig{
|
||||||
|
Price1K: apiKey.Group.ImagePrice1K,
|
||||||
|
Price2K: apiKey.Group.ImagePrice2K,
|
||||||
|
Price4K: apiKey.Group.ImagePrice4K,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||||
|
if s.resolver == nil || apiKey == nil || apiKey.Group == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
gid := apiKey.Group.ID
|
||||||
|
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||||
|
if resolved.Source == PricingSourceChannel {
|
||||||
|
return resolved
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return result.Usage.InputTokens > 0 ||
|
||||||
|
result.Usage.OutputTokens > 0 ||
|
||||||
|
result.Usage.CacheCreationInputTokens > 0 ||
|
||||||
|
result.Usage.CacheReadInputTokens > 0 ||
|
||||||
|
result.Usage.ImageOutputTokens > 0
|
||||||
|
}
|
||||||
|
|
||||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||||
|
|||||||
2013
backend/internal/service/openai_images.go
Normal file
2013
backend/internal/service/openai_images.go
Normal file
File diff suppressed because it is too large
Load Diff
105
backend/internal/service/openai_images_test.go
Normal file
105
backend/internal/service/openai_images_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "/v1/images/generations", parsed.Endpoint)
|
||||||
|
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||||
|
require.Equal(t, "draw a cat", parsed.Prompt)
|
||||||
|
require.True(t, parsed.Stream)
|
||||||
|
require.Equal(t, "1024x1024", parsed.Size)
|
||||||
|
require.Equal(t, "1K", parsed.SizeTier)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
|
require.False(t, parsed.Multipart)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||||
|
require.NoError(t, writer.WriteField("prompt", "replace background"))
|
||||||
|
require.NoError(t, writer.WriteField("size", "1536x1024"))
|
||||||
|
part, err := writer.CreateFormFile("image", "source.png")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = part.Write([]byte("fake-image-bytes"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "/v1/images/edits", parsed.Endpoint)
|
||||||
|
require.True(t, parsed.Multipart)
|
||||||
|
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||||
|
require.Equal(t, "replace background", parsed.Prompt)
|
||||||
|
require.Equal(t, "1536x1024", parsed.Size)
|
||||||
|
require.Equal(t, "2K", parsed.SizeTier)
|
||||||
|
require.Len(t, parsed.Uploads, 1)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"prompt":"draw a cat"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, parsed)
|
||||||
|
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||||
|
}
|
||||||
@@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta
|
|||||||
func detectOpsRetryType(path string) opsRetryRequestType {
|
func detectOpsRetryType(path string) opsRetryRequestType {
|
||||||
p := strings.ToLower(strings.TrimSpace(path))
|
p := strings.ToLower(strings.TrimSpace(path))
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(p, "/responses"):
|
case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
|
||||||
return opsRetryTypeOpenAI
|
return opsRetryTypeOpenAI
|
||||||
case strings.Contains(p, "/v1beta/"):
|
case strings.Contains(p, "/v1beta/"):
|
||||||
return opsRetryTypeGeminiV1B
|
return opsRetryTypeGeminiV1B
|
||||||
|
|||||||
@@ -305,7 +305,8 @@ func shouldBypassEmbeddedFrontend(path string) bool {
|
|||||||
strings.HasPrefix(trimmed, "/setup/") ||
|
strings.HasPrefix(trimmed, "/setup/") ||
|
||||||
trimmed == "/health" ||
|
trimmed == "/health" ||
|
||||||
trimmed == "/responses" ||
|
trimmed == "/responses" ||
|
||||||
strings.HasPrefix(trimmed, "/responses/")
|
strings.HasPrefix(trimmed, "/responses/") ||
|
||||||
|
strings.HasPrefix(trimmed, "/images/")
|
||||||
}
|
}
|
||||||
|
|
||||||
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
||||||
|
|||||||
Reference in New Issue
Block a user