Merge branch 'dev-release'
This commit is contained in:
@@ -24,6 +24,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
|
||||
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
// 检查嵌套的 response.model 字段
|
||||
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "response.model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
return line
|
||||
@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
|
||||
Reference in New Issue
Block a user