150 lines
3.5 KiB
Go
150 lines
3.5 KiB
Go
package service
|
|
|
|
import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"strings"

	"github.com/tidwall/gjson"
)
|
|
|
|
// openAIImageOutputCounter accumulates image-output tallies across one or more
// response payloads fed to it (whole JSON bodies or individual SSE events).
type openAIImageOutputCounter struct {
	// seen is the set of dedup keys for output items already counted.
	seen map[string]struct{}
	// count is the number of distinct output items recorded so far;
	// maxDataCount is the longest top-level "data" array observed.
	count, maxDataCount int
}
|
|
|
|
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
|
|
return &openAIImageOutputCounter{seen: make(map[string]struct{})}
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) Count() int {
|
|
if c == nil {
|
|
return 0
|
|
}
|
|
if c.maxDataCount > c.count {
|
|
return c.maxDataCount
|
|
}
|
|
return c.count
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
|
|
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
|
|
return
|
|
}
|
|
c.addDataArray(gjson.GetBytes(body, "data"))
|
|
c.addOutputArray(gjson.GetBytes(body, "output"))
|
|
c.addOutputArray(gjson.GetBytes(body, "response.output"))
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
|
|
if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
|
|
return
|
|
}
|
|
root := gjson.ParseBytes(data)
|
|
c.addDataArray(root.Get("data"))
|
|
eventType := strings.TrimSpace(root.Get("type").String())
|
|
switch eventType {
|
|
case "response.output_item.done":
|
|
c.addImageOutputItem(root.Get("item"))
|
|
case "response.completed", "response.done":
|
|
c.addOutputArray(root.Get("response.output"))
|
|
case "image_generation.completed":
|
|
if item := root.Get("item"); item.Exists() {
|
|
c.addImageOutputItem(item)
|
|
return
|
|
}
|
|
if output := root.Get("output"); output.Exists() {
|
|
c.addImageOutputItem(output)
|
|
return
|
|
}
|
|
c.addImageOutputItem(root)
|
|
}
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) AddSSEBody(body string) {
|
|
if c == nil || strings.TrimSpace(body) == "" {
|
|
return
|
|
}
|
|
forEachOpenAISSEDataPayload(body, c.AddSSEData)
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
|
|
if !data.IsArray() {
|
|
return
|
|
}
|
|
count := len(data.Array())
|
|
if count > c.maxDataCount {
|
|
c.maxDataCount = count
|
|
}
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
|
|
if !output.IsArray() {
|
|
return
|
|
}
|
|
output.ForEach(func(_, item gjson.Result) bool {
|
|
c.addImageOutputItem(item)
|
|
return true
|
|
})
|
|
}
|
|
|
|
func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
|
|
if !item.Exists() || !item.IsObject() {
|
|
return
|
|
}
|
|
itemType := strings.TrimSpace(item.Get("type").String())
|
|
if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" {
|
|
return
|
|
}
|
|
if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
|
|
return
|
|
}
|
|
result := strings.TrimSpace(item.Get("result").String())
|
|
if result == "" {
|
|
result = strings.TrimSpace(item.Get("b64_json").String())
|
|
}
|
|
if result == "" {
|
|
result = strings.TrimSpace(item.Get("url").String())
|
|
}
|
|
if result == "" && itemType != "image_generation.completed" {
|
|
return
|
|
}
|
|
key := strings.TrimSpace(item.Get("id").String())
|
|
if key == "" {
|
|
key = strings.TrimSpace(item.Get("call_id").String())
|
|
}
|
|
if key == "" {
|
|
key = hashOpenAIImageOutputResult(result)
|
|
}
|
|
if key == "" {
|
|
return
|
|
}
|
|
if _, exists := c.seen[key]; exists {
|
|
return
|
|
}
|
|
c.seen[key] = struct{}{}
|
|
c.count++
|
|
}
|
|
|
|
// hashOpenAIImageOutputResult returns the lowercase hex SHA-256 digest of
// result with surrounding whitespace removed, or "" when nothing remains
// after trimming.
func hashOpenAIImageOutputResult(result string) string {
	trimmed := strings.TrimSpace(result)
	if len(trimmed) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
|
|
|
|
func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
|
|
counter := newOpenAIImageOutputCounter()
|
|
counter.AddJSONResponse(body)
|
|
return counter.Count()
|
|
}
|
|
|
|
func countOpenAIImageOutputsFromSSEBody(body string) int {
|
|
counter := newOpenAIImageOutputCounter()
|
|
counter.AddSSEBody(body)
|
|
return counter.Count()
|
|
}
|