feat(openai): 同步生图 API 支持并接入图片计费调度
- 同步 OpenAI 图片生成与编辑接口 - 接入图片请求解析、账号调度、转发与用量记录 - 接入图片计费与图片用量落库 - 限制 OAuth 生图仅支持无显式模型和尺寸的基础请求
This commit is contained in:
@@ -911,6 +911,34 @@ func (a *Account) GetChatGPTAccountID() string {
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIDeviceID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(a.GetExtraString("openai_device_id"))
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAISessionID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(a.GetExtraString("openai_session_id"))
|
||||
}
|
||||
|
||||
func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
|
||||
if !a.IsOpenAI() {
|
||||
return false
|
||||
}
|
||||
switch capability {
|
||||
case OpenAIImagesCapabilityBasic:
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
|
||||
case OpenAIImagesCapabilityNative:
|
||||
return a.Type == AccountTypeAPIKey
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
|
||||
@@ -61,6 +61,25 @@ type PricingInput struct {
|
||||
// 1. 获取基础定价(LiteLLM → Fallback)
|
||||
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
|
||||
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
|
||||
var chPricing *ChannelModelPricing
|
||||
if input.GroupID != nil && r.channelService != nil {
|
||||
chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model)
|
||||
if chPricing != nil {
|
||||
mode := chPricing.BillingMode
|
||||
if mode == "" {
|
||||
mode = BillingModeToken
|
||||
}
|
||||
if mode == BillingModePerRequest || mode == BillingModeImage {
|
||||
resolved := &ResolvedPricing{
|
||||
Mode: mode,
|
||||
Source: PricingSourceChannel,
|
||||
}
|
||||
r.applyRequestTierOverrides(chPricing, resolved)
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. 获取基础定价
|
||||
basePricing, source := r.resolveBasePricing(input.Model)
|
||||
|
||||
@@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
|
||||
}
|
||||
|
||||
// 2. 如果有 GroupID,尝试渠道覆盖
|
||||
if input.GroupID != nil {
|
||||
if chPricing != nil {
|
||||
resolved.Source = PricingSourceChannel
|
||||
r.applyTokenOverrides(chPricing, resolved)
|
||||
} else if input.GroupID != nil {
|
||||
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
|
||||
}
|
||||
|
||||
|
||||
@@ -38,13 +38,14 @@ var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSch
|
||||
var openAIAdvancedSchedulerSettingSF singleflight.Group
|
||||
|
||||
type OpenAIAccountScheduleRequest struct {
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
StickyAccountID int64
|
||||
PreviousResponseID string
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
ExcludedIDs map[int64]struct{}
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
StickyAccountID int64
|
||||
PreviousResponseID string
|
||||
RequestedModel string
|
||||
RequiredTransport OpenAIUpstreamTransport
|
||||
RequiredImageCapability OpenAIImagesCapability
|
||||
ExcludedIDs map[int64]struct{}
|
||||
}
|
||||
|
||||
type OpenAIAccountScheduleDecision struct {
|
||||
@@ -340,7 +341,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
|
||||
return nil, nil
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
return nil, nil
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@@ -616,7 +617,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
if !s.isAccountRequestCompatible(account, req) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
|
||||
@@ -722,11 +723,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
@@ -749,7 +750,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
|
||||
for _, candidate := range selectionOrder {
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
|
||||
continue
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
@@ -776,6 +777,16 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
|
||||
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
if s == nil || s.stats == nil {
|
||||
return
|
||||
@@ -894,14 +905,59 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredCapability OpenAIImagesCapability,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) selectAccountWithScheduler(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
previousResponseID string,
|
||||
sessionHash string,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
requiredImageCapability OpenAIImagesCapability,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
scheduler := s.getOpenAIAccountScheduler(ctx)
|
||||
if scheduler == nil {
|
||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
return selection, decision, err
|
||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||
for {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs)
|
||||
if err != nil {
|
||||
return nil, decision, err
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
return selection, decision, nil
|
||||
}
|
||||
if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
|
||||
return selection, decision, nil
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if effectiveExcludedIDs == nil {
|
||||
effectiveExcludedIDs = make(map[int64]struct{})
|
||||
}
|
||||
if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
|
||||
return nil, decision, ErrNoAvailableAccounts
|
||||
}
|
||||
effectiveExcludedIDs[selection.Account.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
|
||||
@@ -937,13 +993,14 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
}
|
||||
|
||||
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
|
||||
GroupID: groupID,
|
||||
SessionHash: sessionHash,
|
||||
StickyAccountID: stickyAccountID,
|
||||
PreviousResponseID: previousResponseID,
|
||||
RequestedModel: requestedModel,
|
||||
RequiredTransport: requiredTransport,
|
||||
ExcludedIDs: excludedIDs,
|
||||
GroupID: groupID,
|
||||
SessionHash: sessionHash,
|
||||
StickyAccountID: stickyAccountID,
|
||||
PreviousResponseID: previousResponseID,
|
||||
RequestedModel: requestedModel,
|
||||
RequiredTransport: requiredTransport,
|
||||
RequiredImageCapability: requiredImageCapability,
|
||||
ExcludedIDs: excludedIDs,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1070,3 +1070,31 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_image_only_usage",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 2,
|
||||
ImageSize: "1K",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 1007},
|
||||
User: &User{ID: 2007},
|
||||
Account: &Account{ID: 3007},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, 2, usageRepo.lastLog.ImageCount)
|
||||
require.NotNil(t, usageRepo.lastLog.ImageSize)
|
||||
require.Equal(t, "1K", *usageRepo.lastLog.ImageSize)
|
||||
require.NotNil(t, usageRepo.lastLog.BillingMode)
|
||||
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
|
||||
}
|
||||
|
||||
@@ -233,6 +233,8 @@ type OpenAIForwardResult struct {
|
||||
ResponseHeaders http.Header
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
ImageCount int
|
||||
ImageSize string
|
||||
}
|
||||
|
||||
type OpenAIWSRetryMetricsSnapshot struct {
|
||||
@@ -3889,6 +3891,7 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
|
||||
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
|
||||
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
||||
usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int())
|
||||
}
|
||||
|
||||
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||
@@ -3900,11 +3903,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
|
||||
"usage.input_tokens",
|
||||
"usage.output_tokens",
|
||||
"usage.input_tokens_details.cached_tokens",
|
||||
"usage.output_tokens_details.image_tokens",
|
||||
)
|
||||
return OpenAIUsage{
|
||||
InputTokens: int(values[0].Int()),
|
||||
OutputTokens: int(values[1].Int()),
|
||||
CacheReadInputTokens: int(values[2].Int()),
|
||||
ImageOutputTokens: int(values[3].Int()),
|
||||
}, true
|
||||
}
|
||||
|
||||
@@ -4397,7 +4402,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
|
||||
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
|
||||
result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
|
||||
result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4451,21 +4457,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if result.ServiceTier != nil {
|
||||
serviceTier = strings.TrimSpace(*result.ServiceTier)
|
||||
}
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
ServiceTier: serviceTier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
} else {
|
||||
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
}
|
||||
cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
@@ -4505,6 +4497,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||
}
|
||||
if cost != nil {
|
||||
usageLog.InputCost = cost.InputCost
|
||||
@@ -4530,6 +4524,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if result.ImageCount > 0 {
|
||||
billingMode := string(BillingModeImage)
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
@@ -4589,6 +4586,125 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
|
||||
ctx context.Context,
|
||||
result *OpenAIForwardResult,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
if result != nil && result.ImageCount > 0 {
|
||||
if hasOpenAIImageUsageTokens(result) {
|
||||
cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize)
|
||||
if err == nil {
|
||||
return cost, nil
|
||||
}
|
||||
}
|
||||
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
|
||||
}
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
return s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
RateMultiplier: multiplier,
|
||||
ServiceTier: serviceTier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
}
|
||||
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost(
|
||||
ctx context.Context,
|
||||
apiKey *APIKey,
|
||||
billingModel string,
|
||||
multiplier float64,
|
||||
tokens UsageTokens,
|
||||
serviceTier string,
|
||||
sizeTier string,
|
||||
) (*CostBreakdown, error) {
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
return s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
Tokens: tokens,
|
||||
RequestCount: 1,
|
||||
SizeTier: sizeTier,
|
||||
RateMultiplier: multiplier,
|
||||
ServiceTier: serviceTier,
|
||||
Resolver: s.resolver,
|
||||
})
|
||||
}
|
||||
return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) calculateOpenAIImageCost(
|
||||
ctx context.Context,
|
||||
billingModel string,
|
||||
apiKey *APIKey,
|
||||
result *OpenAIForwardResult,
|
||||
multiplier float64,
|
||||
) *CostBreakdown {
|
||||
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil {
|
||||
gid := apiKey.Group.ID
|
||||
cost, err := s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
RequestCount: 1,
|
||||
SizeTier: result.ImageSize,
|
||||
RateMultiplier: multiplier,
|
||||
Resolver: s.resolver,
|
||||
Resolved: resolved,
|
||||
})
|
||||
if err == nil {
|
||||
return cost
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err)
|
||||
}
|
||||
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey != nil && apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
|
||||
if s.resolver == nil || apiKey == nil || apiKey.Group == nil {
|
||||
return nil
|
||||
}
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
return resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return result.Usage.InputTokens > 0 ||
|
||||
result.Usage.OutputTokens > 0 ||
|
||||
result.Usage.CacheCreationInputTokens > 0 ||
|
||||
result.Usage.CacheReadInputTokens > 0 ||
|
||||
result.Usage.ImageOutputTokens > 0
|
||||
}
|
||||
|
||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
|
||||
2017
backend/internal/service/openai_images.go
Normal file
2017
backend/internal/service/openai_images.go
Normal file
File diff suppressed because it is too large
Load Diff
105
backend/internal/service/openai_images_test.go
Normal file
105
backend/internal/service/openai_images_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, "/v1/images/generations", parsed.Endpoint)
|
||||
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||
require.Equal(t, "draw a cat", parsed.Prompt)
|
||||
require.True(t, parsed.Stream)
|
||||
require.Equal(t, "1024x1024", parsed.Size)
|
||||
require.Equal(t, "1K", parsed.SizeTier)
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
require.False(t, parsed.Multipart)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.NoError(t, writer.WriteField("model", "gpt-image-2"))
|
||||
require.NoError(t, writer.WriteField("prompt", "replace background"))
|
||||
require.NoError(t, writer.WriteField("size", "1536x1024"))
|
||||
part, err := writer.CreateFormFile("image", "source.png")
|
||||
require.NoError(t, err)
|
||||
_, err = part.Write([]byte("fake-image-bytes"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, "/v1/images/edits", parsed.Endpoint)
|
||||
require.True(t, parsed.Multipart)
|
||||
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||
require.Equal(t, "replace background", parsed.Prompt)
|
||||
require.Equal(t, "1536x1024", parsed.Size)
|
||||
require.Equal(t, "2K", parsed.SizeTier)
|
||||
require.Len(t, parsed.Uploads, 1)
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"prompt":"draw a cat"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, "gpt-image-2", parsed.Model)
|
||||
require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta
|
||||
func detectOpsRetryType(path string) opsRetryRequestType {
|
||||
p := strings.ToLower(strings.TrimSpace(path))
|
||||
switch {
|
||||
case strings.Contains(p, "/responses"):
|
||||
case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
|
||||
return opsRetryTypeOpenAI
|
||||
case strings.Contains(p, "/v1beta/"):
|
||||
return opsRetryTypeGeminiV1B
|
||||
|
||||
Reference in New Issue
Block a user