feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
cfg: cfg,
}
}
@@ -115,6 +116,15 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
// apikey 类型账号HTTP 透传到上游,不走 SoraSDKClient
if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
if s.httpUpstream == nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
}
return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
}
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
taskID := ""
var err error
videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
}
// 直调路径(/sora/v1/chat/completions保持纯透传不执行本地/S3 媒体落盘。
// 媒体存储由客户端 API 路径(/api/v1/sora/generate的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
if storeErr != nil {
// 存储失败时降级使用原始 URL不中断用户请求
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
} else {
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
}
}
func parseSoraVideoCount(body map[string]any) int {
if body == nil {
return 1
}
keys := []string{"video_count", "videos", "n_variants"}
for _, key := range keys {
count := parseIntWithDefault(body, key, 0)
if count > 0 {
return clampInt(count, 1, 3)
}
}
return 1
}
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
return def
}
func parseIntWithDefault(body map[string]any, key string, def int) int {
if body == nil {
return def
}
val, ok := body[key]
if !ok {
return def
}
switch typed := val.(type) {
case int:
return typed
case int32:
return int(typed)
case int64:
return int(typed)
case float64:
return int(typed)
case string:
parsed, err := strconv.Atoi(strings.TrimSpace(typed))
if err == nil {
return parsed
}
}
return def
}
func clampInt(v, minVal, maxVal int) int {
if v < minVal {
return minVal
}
if v > maxVal {
return maxVal
}
return v
}
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
accountID := int64(0)
if account != nil {
accountID = account.ID
}
logger.LegacyPrintf(
"service.sora",
"[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
accountID,
model,
upstreamErr.StatusCode,
strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
strings.TrimSpace(upstreamErr.Message),
truncateForLog(upstreamErr.Body, 1024),
)
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}