fix(sora): 修复流式重写与计费问题
This commit is contained in:
@@ -928,6 +928,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
|
viper.SetDefault("sora2api.admin_token_ttl_seconds", 900)
|
||||||
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
|
viper.SetDefault("sora2api.admin_timeout_seconds", 10)
|
||||||
viper.SetDefault("sora2api.token_import_mode", "at")
|
viper.SetDefault("sora2api.token_import_mode", "at")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
@@ -1263,20 +1264,6 @@ func (c *Config) Validate() error {
|
|||||||
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
|
if err := ValidateAbsoluteHTTPURL(c.Sora2API.BaseURL); err != nil {
|
||||||
return fmt.Errorf("sora2api.base_url invalid: %w", err)
|
return fmt.Errorf("sora2api.base_url invalid: %w", err)
|
||||||
}
|
}
|
||||||
warnIfInsecureURL("sora2api.base_url", c.Sora2API.BaseURL)
|
|
||||||
}
|
|
||||||
if mode := strings.TrimSpace(strings.ToLower(c.Sora2API.TokenImportMode)); mode != "" {
|
|
||||||
switch mode {
|
|
||||||
case "at", "offline":
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("sora2api.token_import_mode must be one of: at/offline")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Sora2API.AdminTokenTTLSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora2api.admin_token_ttl_seconds must be non-negative")
|
|
||||||
}
|
|
||||||
if c.Sora2API.AdminTimeoutSeconds < 0 {
|
|
||||||
return fmt.Errorf("sora2api.admin_timeout_seconds must be non-negative")
|
|
||||||
}
|
}
|
||||||
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
if c.Ops.MetricsCollectorCache.TTL < 0 {
|
||||||
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative")
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type SoraGatewayHandler struct {
|
|||||||
streamMode string
|
streamMode string
|
||||||
sora2apiBaseURL string
|
sora2apiBaseURL string
|
||||||
soraMediaSigningKey string
|
soraMediaSigningKey string
|
||||||
|
mediaClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||||
@@ -61,6 +62,10 @@ func NewSoraGatewayHandler(
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
|
baseURL = strings.TrimRight(strings.TrimSpace(cfg.Sora2API.BaseURL), "/")
|
||||||
}
|
}
|
||||||
|
mediaTimeout := 180 * time.Second
|
||||||
|
if cfg != nil && cfg.Gateway.SoraRequestTimeoutSeconds > 0 {
|
||||||
|
mediaTimeout = time.Duration(cfg.Gateway.SoraRequestTimeoutSeconds) * time.Second
|
||||||
|
}
|
||||||
return &SoraGatewayHandler{
|
return &SoraGatewayHandler{
|
||||||
gatewayService: gatewayService,
|
gatewayService: gatewayService,
|
||||||
soraGatewayService: soraGatewayService,
|
soraGatewayService: soraGatewayService,
|
||||||
@@ -70,6 +75,7 @@ func NewSoraGatewayHandler(
|
|||||||
streamMode: strings.ToLower(streamMode),
|
streamMode: strings.ToLower(streamMode),
|
||||||
sora2apiBaseURL: baseURL,
|
sora2apiBaseURL: baseURL,
|
||||||
soraMediaSigningKey: signKey,
|
soraMediaSigningKey: signKey,
|
||||||
|
mediaClient: &http.Client{Timeout: mediaTimeout},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,7 +463,11 @@ func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature boo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
client := h.mediaClient
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Status(http.StatusBadGateway)
|
c.Status(http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1565,7 +1565,7 @@ func itoa(v int) string {
|
|||||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||||
//
|
//
|
||||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value interface{}) ([]service.Account, error) {
|
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||||
accounts, err := r.client.Account.Query().
|
accounts, err := r.client.Account.Query().
|
||||||
Where(
|
Where(
|
||||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID in
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return nil, nil // 记录不存在
|
return nil, nil // 记录不存在
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ type AccountRepository interface {
|
|||||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||||
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
||||||
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
||||||
FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error)
|
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||||
Update(ctx context.Context, account *Account) error
|
Update(ctx context.Context, account *Account) error
|
||||||
Delete(ctx context.Context, id int64) error
|
Delete(ctx context.Context, id int64) error
|
||||||
|
|
||||||
|
|||||||
@@ -3465,7 +3465,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
|
|
||||||
// 根据请求类型选择计费方式
|
// 根据请求类型选择计费方式
|
||||||
if result.MediaType == "image" || result.MediaType == "video" || result.MediaType == "prompt" {
|
if result.MediaType == "image" || result.MediaType == "video" {
|
||||||
var soraConfig *SoraPriceConfig
|
var soraConfig *SoraPriceConfig
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
soraConfig = &SoraPriceConfig{
|
soraConfig = &SoraPriceConfig{
|
||||||
@@ -3480,6 +3480,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
} else {
|
} else {
|
||||||
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier)
|
||||||
}
|
}
|
||||||
|
} else if result.MediaType == "prompt" {
|
||||||
|
cost = &CostBreakdown{}
|
||||||
} else if result.ImageCount > 0 {
|
} else if result.ImageCount > 0 {
|
||||||
// 图片生成计费
|
// 图片生成计费
|
||||||
var groupConfig *ImagePriceConfig
|
var groupConfig *ImagePriceConfig
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ type Sora2APIService struct {
|
|||||||
adminUsername string
|
adminUsername string
|
||||||
adminPassword string
|
adminPassword string
|
||||||
adminTokenTTL time.Duration
|
adminTokenTTL time.Duration
|
||||||
adminTimeout time.Duration
|
|
||||||
tokenImportMode string
|
tokenImportMode string
|
||||||
|
|
||||||
client *http.Client
|
client *http.Client
|
||||||
@@ -73,7 +72,6 @@ type Sora2APIService struct {
|
|||||||
adminMu sync.Mutex
|
adminMu sync.Mutex
|
||||||
|
|
||||||
modelCache []Sora2APIModel
|
modelCache []Sora2APIModel
|
||||||
modelCacheAt time.Time
|
|
||||||
modelMu sync.RWMutex
|
modelMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,7 +94,6 @@ func NewSora2APIService(cfg *config.Config) *Sora2APIService {
|
|||||||
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
adminUsername: strings.TrimSpace(cfg.Sora2API.AdminUsername),
|
||||||
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
adminPassword: strings.TrimSpace(cfg.Sora2API.AdminPassword),
|
||||||
adminTokenTTL: adminTTL,
|
adminTokenTTL: adminTTL,
|
||||||
adminTimeout: adminTimeout,
|
|
||||||
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
tokenImportMode: strings.ToLower(strings.TrimSpace(cfg.Sora2API.TokenImportMode)),
|
||||||
client: &http.Client{},
|
client: &http.Client{},
|
||||||
adminClient: &http.Client{Timeout: adminTimeout},
|
adminClient: &http.Client{Timeout: adminTimeout},
|
||||||
@@ -176,7 +173,6 @@ func (s *Sora2APIService) ListModels(ctx context.Context) ([]Sora2APIModel, erro
|
|||||||
|
|
||||||
s.modelMu.Lock()
|
s.modelMu.Lock()
|
||||||
s.modelCache = models
|
s.modelCache = models
|
||||||
s.modelCacheAt = time.Now()
|
|
||||||
s.modelMu.Unlock()
|
s.modelMu.Unlock()
|
||||||
|
|
||||||
return models, nil
|
return models, nil
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ var soraSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
|||||||
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
var soraImageMarkdownRe = regexp.MustCompile(`!\[[^\]]*\]\(([^)]+)\)`)
|
||||||
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
var soraVideoHTMLRe = regexp.MustCompile(`(?i)<video[^>]+src=['"]([^'"]+)['"]`)
|
||||||
|
|
||||||
|
const soraRewriteBufferLimit = 2048
|
||||||
|
|
||||||
var soraImageSizeMap = map[string]string{
|
var soraImageSizeMap = map[string]string{
|
||||||
"gpt-image": "360",
|
"gpt-image": "360",
|
||||||
"gpt-image-landscape": "540",
|
"gpt-image-landscape": "540",
|
||||||
@@ -30,7 +32,6 @@ var soraImageSizeMap = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
type soraStreamingResult struct {
|
type soraStreamingResult struct {
|
||||||
content string
|
|
||||||
mediaType string
|
mediaType string
|
||||||
mediaURLs []string
|
mediaURLs []string
|
||||||
imageCount int
|
imageCount int
|
||||||
@@ -307,6 +308,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
contentBuilder := strings.Builder{}
|
contentBuilder := strings.Builder{}
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var upstreamError error
|
var upstreamError error
|
||||||
|
rewriteBuffer := ""
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
maxLineSize := defaultMaxLineSize
|
maxLineSize := defaultMaxLineSize
|
||||||
@@ -333,12 +335,29 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
if soraSSEDataRe.MatchString(line) {
|
if soraSSEDataRe.MatchString(line) {
|
||||||
data := soraSSEDataRe.ReplaceAllString(line, "")
|
data := soraSSEDataRe.ReplaceAllString(line, "")
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
|
if rewriteBuffer != "" {
|
||||||
|
flushLine, flushContent, err := s.flushSoraRewriteBuffer(rewriteBuffer, originalModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if flushLine != "" {
|
||||||
|
if flushContent != "" {
|
||||||
|
if _, err := contentBuilder.WriteString(flushContent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := sendLine(flushLine); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriteBuffer = ""
|
||||||
|
}
|
||||||
if err := sendLine("data: [DONE]"); err != nil {
|
if err := sendLine("data: [DONE]"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel)
|
updatedLine, contentDelta, errEvent := s.processSoraSSEData(data, originalModel, &rewriteBuffer)
|
||||||
if errEvent != nil && upstreamError == nil {
|
if errEvent != nil && upstreamError == nil {
|
||||||
upstreamError = errEvent
|
upstreamError = errEvent
|
||||||
}
|
}
|
||||||
@@ -347,7 +366,9 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
}
|
}
|
||||||
contentBuilder.WriteString(contentDelta)
|
if _, err := contentBuilder.WriteString(contentDelta); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := sendLine(updatedLine); err != nil {
|
if err := sendLine(updatedLine); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -417,7 +438,6 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &soraStreamingResult{
|
return &soraStreamingResult{
|
||||||
content: content,
|
|
||||||
mediaType: mediaType,
|
mediaType: mediaType,
|
||||||
mediaURLs: mediaURLs,
|
mediaURLs: mediaURLs,
|
||||||
imageCount: imageCount,
|
imageCount: imageCount,
|
||||||
@@ -426,7 +446,7 @@ func (s *SoraGatewayService) handleStreamingResponse(ctx context.Context, resp *
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string) (string, string, error) {
|
func (s *SoraGatewayService) processSoraSSEData(data string, originalModel string, rewriteBuffer *string) (string, string, error) {
|
||||||
if strings.TrimSpace(data) == "" {
|
if strings.TrimSpace(data) == "" {
|
||||||
return "data: ", "", nil
|
return "data: ", "", nil
|
||||||
}
|
}
|
||||||
@@ -448,7 +468,12 @@ func (s *SoraGatewayService) processSoraSSEData(data string, originalModel strin
|
|||||||
|
|
||||||
contentDelta, updated := extractSoraContent(payload)
|
contentDelta, updated := extractSoraContent(payload)
|
||||||
if updated {
|
if updated {
|
||||||
rewritten := s.rewriteSoraContent(contentDelta)
|
var rewritten string
|
||||||
|
if rewriteBuffer != nil {
|
||||||
|
rewritten = s.rewriteSoraContentWithBuffer(contentDelta, rewriteBuffer)
|
||||||
|
} else {
|
||||||
|
rewritten = s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
if rewritten != contentDelta {
|
if rewritten != contentDelta {
|
||||||
applySoraContent(payload, rewritten)
|
applySoraContent(payload, rewritten)
|
||||||
contentDelta = rewritten
|
contentDelta = rewritten
|
||||||
@@ -504,6 +529,78 @@ func applySoraContent(payload map[string]any, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) rewriteSoraContentWithBuffer(contentDelta string, buffer *string) string {
|
||||||
|
if buffer == nil {
|
||||||
|
return s.rewriteSoraContent(contentDelta)
|
||||||
|
}
|
||||||
|
if contentDelta == "" && *buffer == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
combined := *buffer + contentDelta
|
||||||
|
rewritten := s.rewriteSoraContent(combined)
|
||||||
|
bufferStart := s.findSoraRewriteBufferStart(rewritten)
|
||||||
|
if bufferStart < 0 {
|
||||||
|
*buffer = ""
|
||||||
|
return rewritten
|
||||||
|
}
|
||||||
|
if len(rewritten)-bufferStart > soraRewriteBufferLimit {
|
||||||
|
bufferStart = len(rewritten) - soraRewriteBufferLimit
|
||||||
|
}
|
||||||
|
output := rewritten[:bufferStart]
|
||||||
|
*buffer = rewritten[bufferStart:]
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) findSoraRewriteBufferStart(content string) int {
|
||||||
|
minIndex := -1
|
||||||
|
start := 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(content[start:], "![")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraImageMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + 2
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(content)
|
||||||
|
start = 0
|
||||||
|
for {
|
||||||
|
idx := strings.Index(lower[start:], "<video")
|
||||||
|
if idx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
idx += start
|
||||||
|
if !hasSoraVideoMatchAt(content, idx) {
|
||||||
|
if minIndex == -1 || idx < minIndex {
|
||||||
|
minIndex = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = idx + len("<video")
|
||||||
|
}
|
||||||
|
return minIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraImageMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraImageMarkdownRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSoraVideoMatchAt(content string, idx int) bool {
|
||||||
|
if idx < 0 || idx >= len(content) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
loc := soraVideoHTMLRe.FindStringIndex(content[idx:])
|
||||||
|
return loc != nil && loc[0] == 0
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
||||||
if content == "" {
|
if content == "" {
|
||||||
return content
|
return content
|
||||||
@@ -533,6 +630,31 @@ func (s *SoraGatewayService) rewriteSoraContent(content string) string {
|
|||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SoraGatewayService) flushSoraRewriteBuffer(buffer string, originalModel string) (string, string, error) {
|
||||||
|
if buffer == "" {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
rewritten := s.rewriteSoraContent(buffer)
|
||||||
|
payload := map[string]any{
|
||||||
|
"choices": []any{
|
||||||
|
map[string]any{
|
||||||
|
"delta": map[string]any{
|
||||||
|
"content": rewritten,
|
||||||
|
},
|
||||||
|
"index": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if originalModel != "" {
|
||||||
|
payload["model"] = originalModel
|
||||||
|
}
|
||||||
|
updatedData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return "data: " + string(updatedData), rewritten, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
func (s *SoraGatewayService) rewriteSoraURL(raw string) string {
|
||||||
raw = strings.TrimSpace(raw)
|
raw = strings.TrimSpace(raw)
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
|
|||||||
@@ -15,9 +15,15 @@ func SignSoraMediaURL(path string, query string, expires int64, key string) stri
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
mac := hmac.New(sha256.New, []byte(key))
|
mac := hmac.New(sha256.New, []byte(key))
|
||||||
mac.Write([]byte(buildSoraMediaSignPayload(path, query)))
|
if _, err := mac.Write([]byte(buildSoraMediaSignPayload(path, query))); err != nil {
|
||||||
mac.Write([]byte("|"))
|
return ""
|
||||||
mac.Write([]byte(strconv.FormatInt(expires, 10)))
|
}
|
||||||
|
if _, err := mac.Write([]byte("|")); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, err := mac.Write([]byte(strconv.FormatInt(expires, 10))); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return hex.EncodeToString(mac.Sum(nil))
|
return hex.EncodeToString(mac.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,9 @@ import (
|
|||||||
// 定期检查并刷新即将过期的token
|
// 定期检查并刷新即将过期的token
|
||||||
type TokenRefreshService struct {
|
type TokenRefreshService struct {
|
||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
|
|
||||||
refreshers []TokenRefresher
|
refreshers []TokenRefresher
|
||||||
cfg *config.TokenRefreshConfig
|
cfg *config.TokenRefreshConfig
|
||||||
cacheInvalidator TokenCacheInvalidator
|
cacheInvalidator TokenCacheInvalidator
|
||||||
soraSyncService *Sora2APISyncService
|
|
||||||
|
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -57,7 +55,6 @@ func NewTokenRefreshService(
|
|||||||
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
|
// 用于在 OpenAI Token 刷新时同步更新 sora_accounts 表
|
||||||
// 需要在 Start() 之前调用
|
// 需要在 Start() 之前调用
|
||||||
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||||
s.soraAccountRepo = repo
|
|
||||||
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
|
// 将 soraAccountRepo 注入到 OpenAITokenRefresher
|
||||||
for _, refresher := range s.refreshers {
|
for _, refresher := range s.refreshers {
|
||||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||||
@@ -69,7 +66,6 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
|||||||
// SetSoraSyncService 设置 Sora2API 同步服务
|
// SetSoraSyncService 设置 Sora2API 同步服务
|
||||||
// 需要在 Start() 之前调用
|
// 需要在 Start() 之前调用
|
||||||
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
func (s *TokenRefreshService) SetSoraSyncService(svc *Sora2APISyncService) {
|
||||||
s.soraSyncService = svc
|
|
||||||
for _, refresher := range s.refreshers {
|
for _, refresher := range s.refreshers {
|
||||||
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
if openaiRefresher, ok := refresher.(*OpenAITokenRefresher); ok {
|
||||||
openaiRefresher.SetSoraSyncService(svc)
|
openaiRefresher.SetSoraSyncService(svc)
|
||||||
|
|||||||
@@ -51,9 +51,7 @@ func ProvideTokenRefreshService(
|
|||||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
|
||||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||||
if soraSyncService != nil {
|
|
||||||
svc.SetSoraSyncService(soraSyncService)
|
svc.SetSoraSyncService(soraSyncService)
|
||||||
}
|
|
||||||
svc.Start()
|
svc.Start()
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
@@ -242,8 +240,6 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewAntigravityTokenProvider,
|
NewAntigravityTokenProvider,
|
||||||
NewOpenAITokenProvider,
|
NewOpenAITokenProvider,
|
||||||
NewClaudeTokenProvider,
|
NewClaudeTokenProvider,
|
||||||
NewSora2APIService,
|
|
||||||
NewSora2APISyncService,
|
|
||||||
NewAntigravityGatewayService,
|
NewAntigravityGatewayService,
|
||||||
ProvideRateLimitService,
|
ProvideRateLimitService,
|
||||||
NewAccountUsageService,
|
NewAccountUsageService,
|
||||||
|
|||||||
Reference in New Issue
Block a user