Files
sub2api/backend/internal/service/sora_upstream_forwarder.go
2026-02-28 15:01:20 +08:00

150 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions"
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func (s *SoraGatewayService) forwardToUpstream(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
clientStream bool,
startTime time.Time,
) (*ForwardResult, error) {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
}
baseURL := account.GetBaseURL()
if baseURL == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
}
// 校验 scheme 合法性(仅允许 http/https
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
}
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
// 构建上游请求
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
return nil, fmt.Errorf("create upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
// 透传客户端的部分请求头
for _, header := range []string{"Accept", "Accept-Encoding"} {
if v := c.GetHeader(header); v != "" {
upstreamReq.Header.Set(header, v)
}
}
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
}
}
defer func() {
_ = resp.Body.Close()
}()
// 错误响应处理
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
// 非转移错误,直接透传给客户端
c.Status(resp.StatusCode)
for key, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
if _, err := c.Writer.Write(respBody); err != nil {
return nil, fmt.Errorf("write upstream error response: %w", err)
}
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 成功响应 — 直接透传
c.Status(resp.StatusCode)
for key, values := range resp.Header {
lower := strings.ToLower(key)
// 透传内容相关头部
if lower == "content-type" || lower == "transfer-encoding" ||
lower == "cache-control" || lower == "x-request-id" {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
}
// 流式复制响应体
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, err := c.Writer.Write(buf[:n]); err != nil {
return nil, fmt.Errorf("stream upstream response write: %w", err)
}
flusher.Flush()
}
if readErr != nil {
break
}
}
} else {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return nil, fmt.Errorf("copy upstream response: %w", err)
}
}
duration := time.Since(startTime)
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Model: "", // 由调用方填充
Stream: clientStream,
Duration: duration,
}, nil
}