150 lines
4.6 KiB
Go
150 lines
4.6 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
|
||
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
|
||
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
|
||
// 支持流式和非流式响应的直接透传。
|
||
func (s *SoraGatewayService) forwardToUpstream(
|
||
ctx context.Context,
|
||
c *gin.Context,
|
||
account *Account,
|
||
body []byte,
|
||
clientStream bool,
|
||
startTime time.Time,
|
||
) (*ForwardResult, error) {
|
||
apiKey := account.GetCredential("api_key")
|
||
if apiKey == "" {
|
||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
|
||
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
|
||
}
|
||
|
||
baseURL := account.GetBaseURL()
|
||
if baseURL == "" {
|
||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
|
||
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
|
||
}
|
||
// 校验 scheme 合法性(仅允许 http/https)
|
||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
|
||
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
|
||
}
|
||
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
|
||
|
||
// 构建上游请求
|
||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||
if err != nil {
|
||
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
|
||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||
}
|
||
|
||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
// 透传客户端的部分请求头
|
||
for _, header := range []string{"Accept", "Accept-Encoding"} {
|
||
if v := c.GetHeader(header); v != "" {
|
||
upstreamReq.Header.Set(header, v)
|
||
}
|
||
}
|
||
|
||
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
|
||
|
||
// 获取代理 URL
|
||
proxyURL := ""
|
||
if account.ProxyID != nil && account.Proxy != nil {
|
||
proxyURL = account.Proxy.URL()
|
||
}
|
||
|
||
// 发送请求
|
||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||
if err != nil {
|
||
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: http.StatusBadGateway,
|
||
}
|
||
}
|
||
defer func() {
|
||
_ = resp.Body.Close()
|
||
}()
|
||
|
||
// 错误响应处理
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
|
||
|
||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||
return nil, &UpstreamFailoverError{
|
||
StatusCode: resp.StatusCode,
|
||
ResponseBody: respBody,
|
||
ResponseHeaders: resp.Header.Clone(),
|
||
}
|
||
}
|
||
|
||
// 非转移错误,直接透传给客户端
|
||
c.Status(resp.StatusCode)
|
||
for key, values := range resp.Header {
|
||
for _, v := range values {
|
||
c.Writer.Header().Add(key, v)
|
||
}
|
||
}
|
||
if _, err := c.Writer.Write(respBody); err != nil {
|
||
return nil, fmt.Errorf("write upstream error response: %w", err)
|
||
}
|
||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||
}
|
||
|
||
// 成功响应 — 直接透传
|
||
c.Status(resp.StatusCode)
|
||
for key, values := range resp.Header {
|
||
lower := strings.ToLower(key)
|
||
// 透传内容相关头部
|
||
if lower == "content-type" || lower == "transfer-encoding" ||
|
||
lower == "cache-control" || lower == "x-request-id" {
|
||
for _, v := range values {
|
||
c.Writer.Header().Add(key, v)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 流式复制响应体
|
||
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
|
||
buf := make([]byte, 4096)
|
||
for {
|
||
n, readErr := resp.Body.Read(buf)
|
||
if n > 0 {
|
||
if _, err := c.Writer.Write(buf[:n]); err != nil {
|
||
return nil, fmt.Errorf("stream upstream response write: %w", err)
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
if readErr != nil {
|
||
break
|
||
}
|
||
}
|
||
} else {
|
||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||
return nil, fmt.Errorf("copy upstream response: %w", err)
|
||
}
|
||
}
|
||
|
||
duration := time.Since(startTime)
|
||
return &ForwardResult{
|
||
RequestID: resp.Header.Get("x-request-id"),
|
||
Model: "", // 由调用方填充
|
||
Stream: clientStream,
|
||
Duration: duration,
|
||
}, nil
|
||
}
|