221 lines
6.2 KiB
Go
221 lines
6.2 KiB
Go
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
const (
|
|
openAIResponsesEndpoint = "/v1/responses"
|
|
openAIResponsesCompactEndpoint = "/v1/responses/compact"
|
|
imageGenerationPermissionMessage = "Image generation is not enabled for this group"
|
|
)
|
|
|
|
// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups.
|
|
func ImageGenerationPermissionMessage() string {
|
|
return imageGenerationPermissionMessage
|
|
}
|
|
|
|
// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present.
|
|
func GroupAllowsImageGeneration(group *Group) bool {
|
|
return group == nil || group.AllowImageGeneration
|
|
}
|
|
|
|
// IsImageGenerationIntent classifies requests that can produce generated images.
|
|
func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
|
|
if IsImageGenerationEndpoint(endpoint) {
|
|
return true
|
|
}
|
|
if isOpenAIImageGenerationModel(requestedModel) {
|
|
return true
|
|
}
|
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
|
return false
|
|
}
|
|
if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) {
|
|
return true
|
|
}
|
|
if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) {
|
|
return true
|
|
}
|
|
return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
|
|
}
|
|
|
|
// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation.
|
|
func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
|
|
if IsImageGenerationEndpoint(endpoint) {
|
|
return true
|
|
}
|
|
if isOpenAIImageGenerationModel(requestedModel) {
|
|
return true
|
|
}
|
|
if reqBody == nil {
|
|
return false
|
|
}
|
|
if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) {
|
|
return true
|
|
}
|
|
if hasOpenAIImageGenerationTool(reqBody) {
|
|
return true
|
|
}
|
|
return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
|
|
}
|
|
|
|
// IsImageGenerationEndpoint identifies dedicated generated-image endpoints.
|
|
func IsImageGenerationEndpoint(endpoint string) bool {
|
|
switch normalizeImageGenerationEndpoint(endpoint) {
|
|
case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func normalizeImageGenerationEndpoint(endpoint string) string {
|
|
endpoint = strings.TrimSpace(strings.ToLower(endpoint))
|
|
if endpoint == "" {
|
|
return ""
|
|
}
|
|
endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com")
|
|
if idx := strings.IndexByte(endpoint, '?'); idx >= 0 {
|
|
endpoint = endpoint[:idx]
|
|
}
|
|
return strings.TrimRight(endpoint, "/")
|
|
}
|
|
|
|
func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
|
|
if !tools.IsArray() {
|
|
return false
|
|
}
|
|
found := false
|
|
tools.ForEach(func(_, item gjson.Result) bool {
|
|
if strings.TrimSpace(item.Get("type").String()) == "image_generation" {
|
|
found = true
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return found
|
|
}
|
|
|
|
func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
|
|
if !choice.Exists() {
|
|
return false
|
|
}
|
|
if choice.Type == gjson.String {
|
|
return strings.TrimSpace(choice.String()) == "image_generation"
|
|
}
|
|
if !choice.IsObject() {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(choice.Get("type").String()) == "image_generation" {
|
|
return true
|
|
}
|
|
if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" {
|
|
return true
|
|
}
|
|
if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
|
|
switch v := choice.(type) {
|
|
case string:
|
|
return strings.TrimSpace(v) == "image_generation"
|
|
case map[string]any:
|
|
if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" {
|
|
return true
|
|
}
|
|
if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" {
|
|
return true
|
|
}
|
|
if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
v, exists := c.Get("api_key")
|
|
if !exists {
|
|
return nil
|
|
}
|
|
apiKey, _ := v.(*APIKey)
|
|
return apiKey
|
|
}
|
|
|
|
func apiKeyGroup(apiKey *APIKey) *Group {
|
|
if apiKey == nil {
|
|
return nil
|
|
}
|
|
return apiKey.Group
|
|
}
|
|
|
|
func cloneRequestMapForImageIntent(body []byte) map[string]any {
|
|
if len(body) == 0 {
|
|
return nil
|
|
}
|
|
var out map[string]any
|
|
if err := json.Unmarshal(body, &out); err != nil {
|
|
return nil
|
|
}
|
|
return out
|
|
}
|
|
|
|
func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
|
|
imageModel := ""
|
|
imageSize := ""
|
|
hasImageTool := false
|
|
if reqBody != nil {
|
|
rawTools, _ := reqBody["tools"].([]any)
|
|
for _, rawTool := range rawTools {
|
|
toolMap, ok := rawTool.(map[string]any)
|
|
if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
|
|
continue
|
|
}
|
|
hasImageTool = true
|
|
imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
|
|
imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
|
|
break
|
|
}
|
|
if imageSize == "" {
|
|
imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
|
|
}
|
|
}
|
|
if imageModel == "" && reqBody != nil {
|
|
bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
|
|
if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
|
|
imageModel = bodyModel
|
|
}
|
|
}
|
|
if imageModel == "" && hasImageTool {
|
|
imageModel = "gpt-image-2"
|
|
}
|
|
if imageModel == "" {
|
|
imageModel = strings.TrimSpace(fallbackModel)
|
|
}
|
|
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
|
|
return imageModel, sizeTier, nil
|
|
}
|
|
|
|
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
|
|
reqBody := cloneRequestMapForImageIntent(body)
|
|
return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
|
|
}
|
|
|
|
func isOpenAIImageBillingModelAlias(model string) bool {
|
|
normalized := strings.ToLower(strings.TrimSpace(model))
|
|
if normalized == "" {
|
|
return false
|
|
}
|
|
return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image")
|
|
}
|