feat(sync): full code sync from release
This commit is contained in:
149
backend/internal/service/sora_upstream_forwarder.go
Normal file
149
backend/internal/service/sora_upstream_forwarder.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
|
||||
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
|
||||
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
|
||||
// 支持流式和非流式响应的直接透传。
|
||||
func (s *SoraGatewayService) forwardToUpstream(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
clientStream bool,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
|
||||
}
|
||||
// 校验 scheme 合法性(仅允许 http/https)
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
|
||||
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
|
||||
}
|
||||
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// 透传客户端的部分请求头
|
||||
for _, header := range []string{"Accept", "Accept-Encoding"} {
|
||||
if v := c.GetHeader(header); v != "" {
|
||||
upstreamReq.Header.Set(header, v)
|
||||
}
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
|
||||
|
||||
// 获取代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 错误响应处理
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
ResponseHeaders: resp.Header.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// 非转移错误,直接透传给客户端
|
||||
c.Status(resp.StatusCode)
|
||||
for key, values := range resp.Header {
|
||||
for _, v := range values {
|
||||
c.Writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
if _, err := c.Writer.Write(respBody); err != nil {
|
||||
return nil, fmt.Errorf("write upstream error response: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 成功响应 — 直接透传
|
||||
c.Status(resp.StatusCode)
|
||||
for key, values := range resp.Header {
|
||||
lower := strings.ToLower(key)
|
||||
// 透传内容相关头部
|
||||
if lower == "content-type" || lower == "transfer-encoding" ||
|
||||
lower == "cache-control" || lower == "x-request-id" {
|
||||
for _, v := range values {
|
||||
c.Writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流式复制响应体
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, err := c.Writer.Write(buf[:n]); err != nil {
|
||||
return nil, fmt.Errorf("stream upstream response write: %w", err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
return nil, fmt.Errorf("copy upstream response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Model: "", // 由调用方填充
|
||||
Stream: clientStream,
|
||||
Duration: duration,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user