feat(sync): full code sync from release
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
|
||||
// SoraGatewayService handles forwarding requests to Sora upstream.
|
||||
type SoraGatewayService struct {
|
||||
soraClient SoraClient
|
||||
mediaStorage *SoraMediaStorage
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
|
||||
|
||||
func NewSoraGatewayService(
|
||||
soraClient SoraClient,
|
||||
mediaStorage *SoraMediaStorage,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayService {
|
||||
return &SoraGatewayService{
|
||||
soraClient: soraClient,
|
||||
mediaStorage: mediaStorage,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
@@ -115,6 +116,15 @@ func NewSoraGatewayService(
|
||||
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
|
||||
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
|
||||
if s.httpUpstream == nil {
|
||||
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
|
||||
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
|
||||
}
|
||||
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
|
||||
}
|
||||
|
||||
if s.soraClient == nil || !s.soraClient.Enabled() {
|
||||
if c != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
|
||||
taskID := ""
|
||||
var err error
|
||||
videoCount := parseSoraVideoCount(reqBody)
|
||||
switch modelCfg.Type {
|
||||
case "image":
|
||||
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
|
||||
@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
Frames: modelCfg.Frames,
|
||||
Model: modelCfg.Model,
|
||||
Size: modelCfg.Size,
|
||||
VideoCount: videoCount,
|
||||
MediaID: mediaID,
|
||||
RemixTargetID: remixTargetID,
|
||||
CameoIDs: extractSoraCameoIDs(reqBody),
|
||||
@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
||||
}
|
||||
}
|
||||
|
||||
// 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
|
||||
// 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
|
||||
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
|
||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||
if storeErr != nil {
|
||||
// 存储失败时降级使用原始 URL,不中断用户请求
|
||||
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
|
||||
} else {
|
||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||
}
|
||||
}
|
||||
if watermarkPostID != "" && watermarkOpts.DeletePost {
|
||||
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
|
||||
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
|
||||
@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSoraVideoCount(body map[string]any) int {
|
||||
if body == nil {
|
||||
return 1
|
||||
}
|
||||
keys := []string{"video_count", "videos", "n_variants"}
|
||||
for _, key := range keys {
|
||||
count := parseIntWithDefault(body, key, 0)
|
||||
if count > 0 {
|
||||
return clampInt(count, 1, 3)
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
|
||||
if body == nil {
|
||||
return def
|
||||
@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
|
||||
return def
|
||||
}
|
||||
|
||||
func parseIntWithDefault(body map[string]any, key string, def int) int {
|
||||
if body == nil {
|
||||
return def
|
||||
}
|
||||
val, ok := body[key]
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
switch typed := val.(type) {
|
||||
case int:
|
||||
return typed
|
||||
case int32:
|
||||
return int(typed)
|
||||
case int64:
|
||||
return int(typed)
|
||||
case float64:
|
||||
return int(typed)
|
||||
case string:
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
|
||||
if err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func clampInt(v, minVal, maxVal int) int {
|
||||
if v < minVal {
|
||||
return minVal
|
||||
}
|
||||
if v > maxVal {
|
||||
return maxVal
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func extractSoraCameoIDs(body map[string]any) []string {
|
||||
if body == nil {
|
||||
return nil
|
||||
@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
|
||||
}
|
||||
var upstreamErr *SoraUpstreamError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"service.sora",
|
||||
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
|
||||
accountID,
|
||||
model,
|
||||
upstreamErr.StatusCode,
|
||||
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
|
||||
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
|
||||
strings.TrimSpace(upstreamErr.Message),
|
||||
truncateForLog(upstreamErr.Body, 1024),
|
||||
)
|
||||
if s.rateLimitService != nil && account != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user