sync: bring over remaining release/custom-0.1.115 changes
- Extract PublicSettingsInjectionPayload named struct with drift test - Add channel_monitor_default_interval_seconds to SSR injection - Add image_output_price to SupportedModelChip - Simplify AppSidebar buildSelfNavItems (admins see available channels) - Add gateway WARN logs for 503 no-available-accounts branches - Wire ChannelMonitorRunner into provideCleanup for graceful shutdown - Add migrations 130/131 (CC template userid fix + mimicry field cleanup) - Clean up fork-only features (sora, claude max simulation, client affinity) - Remove ~320 obsolete i18n keys - Add codexUsage utility, WechatServiceButton, BulkEditAccountModal - Tidy go.sum
This commit is contained in:
@@ -30,6 +30,10 @@ type AccountRepository interface {
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号
|
||||
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||
// CountByTLSFingerprintProfile 按 TLS 指纹模板 ID 聚合每个模板当前被多少账号绑定。
|
||||
// 返回 map[profile_id]count;未绑定任何账号的 profile 不出现在 map 中。
|
||||
// 查询走 108_add_tls_fingerprint_profile_id_index.sql 的表达式索引。
|
||||
CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error)
|
||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||
// for all accounts that have been synced from CRS.
|
||||
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
|
||||
|
||||
@@ -58,6 +58,10 @@ func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, valu
|
||||
panic("unexpected FindByExtraField call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
panic("unexpected CountByTLSFingerprintProfile call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
panic("unexpected ListCRSAccountIDs call")
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
@@ -63,16 +73,6 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
|
||||
@@ -170,11 +170,11 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
return c.usersLoadBatch, c.usersLoadErr
|
||||
}
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) {
|
||||
v, exists := c.Get(OpsSkipPassthroughKey)
|
||||
assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true")
|
||||
boolVal, ok := v.(bool)
|
||||
assert.True(t, ok, "value should be bool")
|
||||
assert.True(t, ok, "value should be a bool")
|
||||
assert.True(t, boolVal)
|
||||
}
|
||||
|
||||
|
||||
@@ -110,13 +110,12 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
},
|
||||
{
|
||||
// Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
|
||||
// second hit 仍然返回 TempUnscheduled。
|
||||
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
|
||||
// Gemini OAuth 401 second hit 会升级为 error(返回 None,交由默认错误逻辑处理)。
|
||||
name: "temp_unschedulable_401_second_hit_gemini_escalates",
|
||||
account: &Account{
|
||||
ID: 15,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Platform: PlatformGemini, // 非 Antigravity 平台 401 second hit 升级
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
@@ -131,7 +130,29 @@ func TestCheckErrorPolicy(t *testing.T) {
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled,
|
||||
expected: ErrorPolicyNone, // Gemini 401 second hit 升级为 error
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_401_antigravity_no_escalation",
|
||||
account: &Account{
|
||||
ID: 16,
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity, // Antigravity 跳过 401 升级,由 rules 正常处理
|
||||
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": float64(401),
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": float64(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
statusCode: 401,
|
||||
body: []byte(`unauthorized`),
|
||||
expected: ErrorPolicyTempUnscheduled, // Antigravity 不升级,继续走规则匹配
|
||||
},
|
||||
{
|
||||
name: "temp_unschedulable_body_miss_returns_none",
|
||||
|
||||
@@ -143,7 +143,6 @@ func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, g
|
||||
func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
s.listByGroupCalls.Add(1)
|
||||
if s.err != nil {
|
||||
|
||||
@@ -82,6 +82,10 @@ func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key s
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@ func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key str
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -781,7 +781,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Acco
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
if req.RequestedModel != "" && !account.IsOpenAIPassthroughEnabled() && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
|
||||
|
||||
@@ -187,13 +187,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
|
||||
}
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return model
|
||||
}
|
||||
|
||||
modelID := model
|
||||
if strings.Contains(modelID, "/") {
|
||||
@@ -235,78 +231,6 @@ func normalizeCodexModel(model string) string {
|
||||
return "gpt-5.4"
|
||||
}
|
||||
|
||||
func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, rawTool := range tools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for _, rawTool := range tools {
|
||||
toolMap, ok := rawTool.(map[string]any)
|
||||
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
||||
continue
|
||||
}
|
||||
if _, ok := toolMap["output_format"]; !ok {
|
||||
if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" {
|
||||
toolMap["output_format"] = value
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["output_compression"]; !ok {
|
||||
if value, exists := toolMap["compression"]; exists && value != nil {
|
||||
toolMap["output_compression"] = value
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["format"]; ok {
|
||||
delete(toolMap, "format")
|
||||
modified = true
|
||||
}
|
||||
if _, ok := toolMap["compression"]; ok {
|
||||
delete(toolMap, "compression")
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
return modified
|
||||
}
|
||||
|
||||
func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
|
||||
if !hasOpenAIImageGenerationTool(reqBody) {
|
||||
return nil
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if !isOpenAIImageGenerationModel(model) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model)
|
||||
}
|
||||
|
||||
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
|
||||
if account == nil || account.Type == AccountTypeOAuth {
|
||||
return normalizeCodexModel(model)
|
||||
|
||||
@@ -217,42 +217,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
require.Equal(t, "bash", first["name"])
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) {
|
||||
reqBody := map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "image_generation",
|
||||
"format": "png",
|
||||
"compression": 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modified := normalizeOpenAIResponsesImageGenerationTools(reqBody)
|
||||
require.True(t, modified)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
first, ok := tools[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "png", first["output_format"])
|
||||
require.Equal(t, 60, first["output_compression"])
|
||||
_, hasFormat := first["format"]
|
||||
require.False(t, hasFormat)
|
||||
_, hasCompression := first["compression"]
|
||||
require.False(t, hasCompression)
|
||||
}
|
||||
|
||||
func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) {
|
||||
err := validateOpenAIResponsesImageModel(map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{"type": "image_generation"},
|
||||
},
|
||||
}, "gpt-image-2")
|
||||
|
||||
require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
|
||||
|
||||
@@ -151,23 +151,38 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
}
|
||||
logger.L().Debug("openai chat_completions: model mapping applied", logFields...)
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
{
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
|
||||
}
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
if codexResult.NormalizedModel != "" {
|
||||
upstreamModel = codexResult.NormalizedModel
|
||||
modified := false
|
||||
if account.Type == AccountTypeOAuth {
|
||||
codexResult := applyCodexOAuthTransform(reqBody, false, false)
|
||||
modified = codexResult.Modified
|
||||
if codexResult.NormalizedModel != "" {
|
||||
upstreamModel = codexResult.NormalizedModel
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
} else {
|
||||
// 非 OAuth 账号也需要提取 system 消息并注入 instructions,
|
||||
// 否则上游 GPT-5/Codex 等模型会报 "Instructions are required"。
|
||||
if extractSystemMessagesFromInput(reqBody) {
|
||||
modified = true
|
||||
}
|
||||
if applyInstructions(reqBody, false) {
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
} else if promptCacheKey != "" {
|
||||
reqBody["prompt_cache_key"] = promptCacheKey
|
||||
}
|
||||
responsesBody, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||
if modified {
|
||||
responsesBody, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1503,7 +1503,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
if requestedModel != "" && !acc.IsOpenAIPassthroughEnabled() && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
|
||||
@@ -1665,7 +1665,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
|
||||
if requestedModel != "" && !fresh.IsOpenAIPassthroughEnabled() && !fresh.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
@@ -1935,12 +1935,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("instructions", "You are a helpful coding assistant.")
|
||||
}
|
||||
|
||||
if normalizeOpenAIResponsesImageGenerationTools(reqBody) {
|
||||
bodyModified = true
|
||||
disablePatch()
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
billingModel := account.GetMappedModel(reqModel)
|
||||
if billingModel != reqModel {
|
||||
@@ -1950,26 +1944,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
markPatchSet("model", billingModel)
|
||||
}
|
||||
upstreamModel := billingModel
|
||||
if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil {
|
||||
setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": err.Error(),
|
||||
"param": "model",
|
||||
},
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if hasOpenAIImageGenerationTool(reqBody) {
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s",
|
||||
reqModel,
|
||||
upstreamModel,
|
||||
account.Type,
|
||||
)
|
||||
}
|
||||
|
||||
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
|
||||
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
|
||||
|
||||
@@ -45,11 +45,8 @@ const (
|
||||
openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare"
|
||||
openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements"
|
||||
|
||||
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
openAIImageRequirementsDiff = "0fffff"
|
||||
openAIImageLifecycleTimeout = 2 * time.Minute
|
||||
openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download
|
||||
openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
|
||||
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
|
||||
openAIImageRequirementsDiff = "0fffff"
|
||||
)
|
||||
|
||||
type OpenAIImagesCapability string
|
||||
@@ -151,9 +148,6 @@ func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []b
|
||||
}
|
||||
|
||||
applyOpenAIImagesDefaults(req)
|
||||
if err := validateOpenAIImagesModel(req.Model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
|
||||
req.RequiredCapability = classifyOpenAIImagesCapability(req)
|
||||
return req, nil
|
||||
@@ -220,7 +214,7 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize))
|
||||
data, err := io.ReadAll(part)
|
||||
_ = part.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read multipart field %s: %w", name, err)
|
||||
@@ -301,21 +295,6 @@ func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
|
||||
req.Model = "gpt-image-2"
|
||||
}
|
||||
|
||||
func isOpenAIImageGenerationModel(model string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-")
|
||||
}
|
||||
|
||||
func validateOpenAIImagesModel(model string) error {
|
||||
model = strings.TrimSpace(model)
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
return nil
|
||||
}
|
||||
if model == "" {
|
||||
return fmt.Errorf("images endpoint requires an image model")
|
||||
}
|
||||
return fmt.Errorf("images endpoint requires an image model, got %q", model)
|
||||
}
|
||||
|
||||
func normalizeOpenAIImagesEndpointPath(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
switch {
|
||||
@@ -421,21 +400,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
|
||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||
requestModel = mapped
|
||||
}
|
||||
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamModel := account.GetMappedModel(requestModel)
|
||||
if err := validateOpenAIImagesModel(upstreamModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s",
|
||||
strings.TrimSpace(parsed.Model),
|
||||
upstreamModel,
|
||||
parsed.Endpoint,
|
||||
account.Type,
|
||||
)
|
||||
forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -794,17 +759,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
|
||||
requestModel = mapped
|
||||
}
|
||||
if err := validateOpenAIImagesModel(requestModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
|
||||
requestModel,
|
||||
parsed.Endpoint,
|
||||
account.Type,
|
||||
len(parsed.Uploads),
|
||||
)
|
||||
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
@@ -890,18 +844,8 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
return nil, err
|
||||
}
|
||||
pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil)
|
||||
logger.LegacyPrintf(
|
||||
"service.openai_gateway",
|
||||
"[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d",
|
||||
conversationID,
|
||||
len(pointerInfos),
|
||||
countOpenAIFileServicePointerInfos(pointerInfos),
|
||||
countOpenAIDirectImageAssets(pointerInfos),
|
||||
)
|
||||
lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout)
|
||||
defer releaseLifecycleCtx()
|
||||
if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) {
|
||||
polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID)
|
||||
polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID)
|
||||
if pollErr != nil {
|
||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr)
|
||||
}
|
||||
@@ -909,11 +853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
|
||||
}
|
||||
pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos)
|
||||
if len(pointerInfos) == 0 {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID)
|
||||
return nil, fmt.Errorf("openai image conversation returned no downloadable images")
|
||||
}
|
||||
|
||||
responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos)
|
||||
responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos)
|
||||
if err != nil {
|
||||
return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err)
|
||||
}
|
||||
@@ -1340,11 +1283,8 @@ func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMess
|
||||
}
|
||||
|
||||
type openAIImagePointerInfo struct {
|
||||
Pointer string
|
||||
DownloadURL string
|
||||
B64JSON string
|
||||
MimeType string
|
||||
Prompt string
|
||||
Pointer string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
type openAIImageToolMessage struct {
|
||||
@@ -1396,6 +1336,10 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
matches := openAIImagePointerMatches(body)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
prompt := ""
|
||||
for _, path := range []string{
|
||||
"message.metadata.dalle.prompt",
|
||||
@@ -1407,12 +1351,11 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
|
||||
break
|
||||
}
|
||||
}
|
||||
matches := openAIImagePointerMatches(body)
|
||||
out := make([]openAIImagePointerInfo, 0, len(matches))
|
||||
for _, pointer := range matches {
|
||||
out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
|
||||
}
|
||||
return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt))
|
||||
return out
|
||||
}
|
||||
|
||||
func openAIImagePointerMatches(body []byte) []string {
|
||||
@@ -1451,72 +1394,27 @@ func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []open
|
||||
seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
|
||||
out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
|
||||
for _, item := range existing {
|
||||
if key := item.identityKey(); key != "" {
|
||||
seen[key] = item
|
||||
}
|
||||
seen[item.Pointer] = item
|
||||
out = append(out, item)
|
||||
}
|
||||
for _, item := range next {
|
||||
key := item.identityKey()
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if existingItem, ok := seen[key]; ok {
|
||||
merged := mergeOpenAIImagePointerInfo(existingItem, item)
|
||||
if merged != existingItem {
|
||||
if existingItem, ok := seen[item.Pointer]; ok {
|
||||
if existingItem.Prompt == "" && item.Prompt != "" {
|
||||
for i := range out {
|
||||
if out[i].identityKey() == key {
|
||||
out[i] = merged
|
||||
if out[i].Pointer == item.Pointer {
|
||||
out[i].Prompt = item.Prompt
|
||||
break
|
||||
}
|
||||
}
|
||||
seen[key] = merged
|
||||
}
|
||||
continue
|
||||
}
|
||||
seen[key] = item
|
||||
seen[item.Pointer] = item
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (i openAIImagePointerInfo) identityKey() string {
|
||||
switch {
|
||||
case strings.TrimSpace(i.Pointer) != "":
|
||||
return "pointer:" + strings.TrimSpace(i.Pointer)
|
||||
case strings.TrimSpace(i.DownloadURL) != "":
|
||||
return "download:" + strings.TrimSpace(i.DownloadURL)
|
||||
case strings.TrimSpace(i.B64JSON) != "":
|
||||
b64 := strings.TrimSpace(i.B64JSON)
|
||||
if len(b64) > 64 {
|
||||
b64 = b64[:64]
|
||||
}
|
||||
return "b64:" + b64
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo {
|
||||
merged := existing
|
||||
if strings.TrimSpace(merged.Pointer) == "" {
|
||||
merged.Pointer = next.Pointer
|
||||
}
|
||||
if strings.TrimSpace(merged.DownloadURL) == "" {
|
||||
merged.DownloadURL = next.DownloadURL
|
||||
}
|
||||
if strings.TrimSpace(merged.B64JSON) == "" {
|
||||
merged.B64JSON = next.B64JSON
|
||||
}
|
||||
if strings.TrimSpace(merged.MimeType) == "" {
|
||||
merged.MimeType = next.MimeType
|
||||
}
|
||||
if strings.TrimSpace(merged.Prompt) == "" {
|
||||
merged.Prompt = next.Prompt
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||
@@ -1526,26 +1424,6 @@ func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int {
|
||||
count := 0
|
||||
for _, item := range items {
|
||||
if strings.HasPrefix(item.Pointer, "file-service://") {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int {
|
||||
count := 0
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo {
|
||||
if !hasOpenAIFileServicePointerInfos(items) {
|
||||
return items
|
||||
@@ -1713,7 +1591,11 @@ func buildOpenAIImageResponse(
|
||||
}
|
||||
items := make([]responseItem, 0, len(pointers))
|
||||
for _, pointer := range pointers {
|
||||
data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer)
|
||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -1733,136 +1615,6 @@ func buildOpenAIImageResponse(
|
||||
return body, len(items), nil
|
||||
}
|
||||
|
||||
func resolveOpenAIImageBytes(
|
||||
ctx context.Context,
|
||||
client *req.Client,
|
||||
headers http.Header,
|
||||
conversationID string,
|
||||
pointer openAIImagePointerInfo,
|
||||
) ([]byte, error) {
|
||||
if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" {
|
||||
return base64.StdEncoding.DecodeString(normalized)
|
||||
}
|
||||
if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" {
|
||||
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
}
|
||||
if strings.TrimSpace(pointer.Pointer) == "" {
|
||||
return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data")
|
||||
}
|
||||
downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
|
||||
}
|
||||
|
||||
func normalizeOpenAIImageBase64(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||
if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) {
|
||||
raw = raw[idx+1:]
|
||||
}
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if _, err := base64.StdEncoding.DecodeString(raw); err != nil {
|
||||
return ""
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return nil
|
||||
}
|
||||
var decoded any
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return nil
|
||||
}
|
||||
var out []openAIImagePointerInfo
|
||||
walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out)
|
||||
return out
|
||||
}
|
||||
|
||||
func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) {
|
||||
switch value := node.(type) {
|
||||
case map[string]any:
|
||||
localPrompt := prompt
|
||||
for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} {
|
||||
if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" {
|
||||
localPrompt = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
}
|
||||
item := openAIImagePointerInfo{
|
||||
Prompt: localPrompt,
|
||||
Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]),
|
||||
DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]),
|
||||
B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]),
|
||||
MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]),
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"),
|
||||
strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"),
|
||||
isLikelyOpenAIImageDownloadURL(item.DownloadURL),
|
||||
normalizeOpenAIImageBase64(item.B64JSON) != "":
|
||||
*out = append(*out, item)
|
||||
}
|
||||
for _, child := range value {
|
||||
walkOpenAIImageInlineAssets(child, localPrompt, out)
|
||||
}
|
||||
case []any:
|
||||
for _, child := range value {
|
||||
walkOpenAIImageInlineAssets(child, prompt, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...any) string {
|
||||
for _, value := range values {
|
||||
if s, ok := value.(string); ok && strings.TrimSpace(s) != "" {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isLikelyOpenAIImageDownloadURL(raw string) bool {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:image/") {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
return strings.Contains(lower, "/download") ||
|
||||
strings.Contains(lower, ".png") ||
|
||||
strings.Contains(lower, ".jpg") ||
|
||||
strings.Contains(lower, ".jpeg") ||
|
||||
strings.Contains(lower, ".webp")
|
||||
}
|
||||
|
||||
func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
if timeout <= 0 {
|
||||
return base, func() {}
|
||||
}
|
||||
return context.WithTimeout(base, timeout)
|
||||
}
|
||||
|
||||
func fetchOpenAIImageDownloadURL(
|
||||
ctx context.Context,
|
||||
client *req.Client,
|
||||
@@ -1954,7 +1706,7 @@ func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers h
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, newOpenAIImageStatusError(resp, "download image bytes failed")
|
||||
}
|
||||
return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes))
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func handleOpenAIImageBackendError(resp *req.Response) error {
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -104,56 +103,3 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNative
|
||||
require.NotNil(t, parsed)
|
||||
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.Nil(t, parsed)
|
||||
require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
|
||||
}
|
||||
|
||||
func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
|
||||
items := collectOpenAIImagePointers([]byte(`{
|
||||
"revised_prompt": "cat astronaut",
|
||||
"parts": [
|
||||
{"b64_json":"QUJD"},
|
||||
{"download_url":"https://files.example.com/image.png?sig=1"},
|
||||
{"asset_pointer":"file-service://file_123"}
|
||||
]
|
||||
}`))
|
||||
|
||||
require.Len(t, items, 3)
|
||||
var sawBase64, sawURL, sawPointer bool
|
||||
for _, item := range items {
|
||||
if item.B64JSON == "QUJD" {
|
||||
sawBase64 = true
|
||||
require.Equal(t, "cat astronaut", item.Prompt)
|
||||
}
|
||||
if item.DownloadURL == "https://files.example.com/image.png?sig=1" {
|
||||
sawURL = true
|
||||
}
|
||||
if item.Pointer == "file-service://file_123" {
|
||||
sawPointer = true
|
||||
}
|
||||
}
|
||||
require.True(t, sawBase64)
|
||||
require.True(t, sawURL)
|
||||
require.True(t, sawPointer)
|
||||
}
|
||||
|
||||
func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
|
||||
data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{
|
||||
B64JSON: "data:image/png;base64,QUJD",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ABC"), data)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,6 @@ func TestNormalizeCodexModel(t *testing.T) {
|
||||
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
|
||||
"gpt-5.3": "gpt-5.3-codex",
|
||||
"gpt-image-2": "gpt-image-2",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
|
||||
@@ -812,16 +812,6 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
return openAIGPT54FallbackPricing
|
||||
}
|
||||
|
||||
if isOpenAIImageGenerationModel(model) {
|
||||
for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} {
|
||||
if pricing, ok := s.pricingData[candidate]; ok {
|
||||
logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate)
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -128,21 +128,6 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t
|
||||
require.Zero(t, got.LongContextInputTokenThreshold)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) {
|
||||
imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3}
|
||||
textPricing := &LiteLLMModelPricing{InputCostPerToken: 9}
|
||||
|
||||
svc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-image-2": imagePricing,
|
||||
"gpt-5.4": textPricing,
|
||||
},
|
||||
}
|
||||
|
||||
got := svc.GetModelPricing("gpt-image-3")
|
||||
require.Same(t, imagePricing, got)
|
||||
}
|
||||
|
||||
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
|
||||
raw := map[string]any{
|
||||
"gpt-5.4": map[string]any{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (m *sessionWindowMockRepo) GetByCRSAccountID(context.Context, string) (*Acc
|
||||
func (m *sessionWindowMockRepo) FindByExtraField(context.Context, string, any) ([]Account, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) CountByTLSFingerprintProfile(context.Context) (map[int64]int, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (m *sessionWindowMockRepo) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@@ -546,8 +546,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
// channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval
|
||||
// (mirrors the monitor-level constraint but lives here so setting_service stays decoupled).
|
||||
const (
|
||||
channelMonitorIntervalMin = 15
|
||||
channelMonitorIntervalMax = 3600
|
||||
channelMonitorIntervalMin = 15
|
||||
channelMonitorIntervalMax = 3600
|
||||
channelMonitorIntervalFallback = 60
|
||||
)
|
||||
|
||||
@@ -578,8 +578,8 @@ func clampChannelMonitorInterval(v int) int {
|
||||
// ChannelMonitorRuntime is the lightweight view of the channel monitor feature
|
||||
// consumed by the runner and user-facing handlers.
|
||||
type ChannelMonitorRuntime struct {
|
||||
Enabled bool
|
||||
DefaultIntervalSeconds int
|
||||
Enabled bool
|
||||
DefaultIntervalSeconds int
|
||||
}
|
||||
|
||||
// GetChannelMonitorRuntime reads the channel monitor feature flags directly from
|
||||
@@ -628,56 +628,76 @@ func (s *SettingService) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection
|
||||
// This implements the web.PublicSettingsProvider interface
|
||||
// PublicSettingsInjectionPayload is the JSON shape embedded into HTML as
|
||||
// `window.__APP_CONFIG__` so the frontend can hydrate feature flags & site
|
||||
// config before the first XHR finishes.
|
||||
//
|
||||
// INVARIANT: every `json` tag here MUST also exist on handler/dto.PublicSettings.
|
||||
// If you forget a feature-flag field here, the frontend's
|
||||
// `cachedPublicSettings.xxx_enabled` will be `undefined` on refresh until the
|
||||
// async `/api/v1/settings/public` call returns — which causes opt-in menus
|
||||
// (strict `=== true`) to flicker off/on. See
|
||||
// frontend/src/utils/featureFlags.ts for the matching registry.
|
||||
//
|
||||
// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
|
||||
// drift automatically (see setting_service_injection_test.go).
|
||||
type PublicSettingsInjectionPayload struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
Version string `json:"version"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
|
||||
// Feature flags — MUST match the opt-in/opt-out registry in
|
||||
// frontend/src/utils/featureFlags.ts. Missing a field here is the bug
|
||||
// that hid the "可用渠道" menu on page refresh.
|
||||
ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
|
||||
// This implements the web.PublicSettingsProvider interface.
|
||||
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
|
||||
settings, err := s.GetPublicSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return a struct that matches the frontend's expected format
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
|
||||
TableDefaultPageSize int `json:"table_default_page_size"`
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
|
||||
WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
Version string `json:"version,omitempty"`
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
}{
|
||||
return &PublicSettingsInjectionPayload{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
@@ -706,17 +726,20 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
|
||||
WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
|
||||
WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
Version: s.version,
|
||||
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
|
||||
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
|
||||
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
|
||||
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
|
||||
ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -126,8 +126,8 @@ type SystemSettings struct {
|
||||
OpsMetricsIntervalSeconds int
|
||||
|
||||
// Channel Monitor feature
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
|
||||
|
||||
// Available Channels feature (user-facing aggregate view)
|
||||
AvailableChannelsEnabled bool `json:"available_channels_enabled"`
|
||||
|
||||
@@ -122,8 +122,8 @@ func TestShouldClearStickySession(t *testing.T) {
|
||||
{
|
||||
name: "overloaded account",
|
||||
account: &Account{
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
OverloadUntil: &future,
|
||||
},
|
||||
requestedModel: "",
|
||||
|
||||
Reference in New Issue
Block a user