feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -0,0 +1,149 @@
package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions"
// 使用 account.GetCredential("api_key") 作为 Bearer Token。
// 支持流式和非流式响应的直接透传。
func (s *SoraGatewayService) forwardToUpstream(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
clientStream bool,
startTime time.Time,
) (*ForwardResult, error) {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
}
baseURL := account.GetBaseURL()
if baseURL == "" {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
}
// 校验 scheme 合法性(仅允许 http/https
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
}
upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
// 构建上游请求
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
return nil, fmt.Errorf("create upstream request: %w", err)
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
// 透传客户端的部分请求头
for _, header := range []string{"Accept", "Accept-Encoding"} {
if v := c.GetHeader(header); v != "" {
upstreamReq.Header.Set(header, v)
}
}
logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
// 获取代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
}
}
defer func() {
_ = resp.Body.Close()
}()
// 错误响应处理
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
ResponseHeaders: resp.Header.Clone(),
}
}
// 非转移错误,直接透传给客户端
c.Status(resp.StatusCode)
for key, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
if _, err := c.Writer.Write(respBody); err != nil {
return nil, fmt.Errorf("write upstream error response: %w", err)
}
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 成功响应 — 直接透传
c.Status(resp.StatusCode)
for key, values := range resp.Header {
lower := strings.ToLower(key)
// 透传内容相关头部
if lower == "content-type" || lower == "transfer-encoding" ||
lower == "cache-control" || lower == "x-request-id" {
for _, v := range values {
c.Writer.Header().Add(key, v)
}
}
}
// 流式复制响应体
if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, err := c.Writer.Write(buf[:n]); err != nil {
return nil, fmt.Errorf("stream upstream response write: %w", err)
}
flusher.Flush()
}
if readErr != nil {
break
}
}
} else {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
return nil, fmt.Errorf("copy upstream response: %w", err)
}
}
duration := time.Since(startTime)
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Model: "", // 由调用方填充
Stream: clientStream,
Duration: duration,
}, nil
}