fix(provider): preserve requested model in antigravity and sora
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: billingModel, // 使用映射模型用于计费和日志
|
Model: originalModel,
|
||||||
|
UpstreamModel: billingModel,
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -2435,7 +2436,8 @@ handleSuccess:
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: billingModel,
|
Model: originalModel,
|
||||||
|
UpstreamModel: billingModel,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
|||||||
@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
|||||||
result, err := svc.Forward(context.Background(), c, account, body, false)
|
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, mappedModel, result.Model)
|
require.Equal(t, "claude-sonnet-4-5", result.Model)
|
||||||
|
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||||
@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
|||||||
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, mappedModel, result.Model)
|
require.Equal(t, "gemini-2.5-flash", result.Model)
|
||||||
|
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||||
@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
|
|||||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, mappedModel, result.Model)
|
require.Equal(t, originalModel, result.Model)
|
||||||
|
require.Equal(t, mappedModel, result.UpstreamModel)
|
||||||
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||||
|
|
||||||
firstReq := string(upstream.requestBodies[0])
|
firstReq := string(upstream.requestBodies[0])
|
||||||
|
|||||||
@@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
|
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
|
||||||
return nil, errors.New("model is required")
|
return nil, errors.New("model is required")
|
||||||
}
|
}
|
||||||
|
originalModel := reqModel
|
||||||
|
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
|
var upstreamModel string
|
||||||
if mappedModel != "" && mappedModel != reqModel {
|
if mappedModel != "" && mappedModel != reqModel {
|
||||||
reqModel = mappedModel
|
reqModel = mappedModel
|
||||||
|
upstreamModel = mappedModel
|
||||||
}
|
}
|
||||||
|
|
||||||
modelCfg, ok := GetSoraModelConfig(reqModel)
|
modelCfg, ok := GetSoraModelConfig(reqModel)
|
||||||
@@ -213,13 +216,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
|
||||||
}
|
}
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: "",
|
RequestID: "",
|
||||||
Model: reqModel,
|
Model: originalModel,
|
||||||
Stream: clientStream,
|
UpstreamModel: upstreamModel,
|
||||||
Duration: time.Since(startTime),
|
Stream: clientStream,
|
||||||
FirstTokenMs: firstTokenMs,
|
Duration: time.Since(startTime),
|
||||||
Usage: ClaudeUsage{},
|
FirstTokenMs: firstTokenMs,
|
||||||
MediaType: "prompt",
|
Usage: ClaudeUsage{},
|
||||||
|
MediaType: "prompt",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,13 +273,14 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: "",
|
RequestID: "",
|
||||||
Model: reqModel,
|
Model: originalModel,
|
||||||
Stream: clientStream,
|
UpstreamModel: upstreamModel,
|
||||||
Duration: time.Since(startTime),
|
Stream: clientStream,
|
||||||
FirstTokenMs: firstTokenMs,
|
Duration: time.Since(startTime),
|
||||||
Usage: ClaudeUsage{},
|
FirstTokenMs: firstTokenMs,
|
||||||
MediaType: "prompt",
|
Usage: ClaudeUsage{},
|
||||||
|
MediaType: "prompt",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
|
||||||
@@ -419,16 +424,17 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: taskID,
|
RequestID: taskID,
|
||||||
Model: reqModel,
|
Model: originalModel,
|
||||||
Stream: clientStream,
|
UpstreamModel: upstreamModel,
|
||||||
Duration: time.Since(startTime),
|
Stream: clientStream,
|
||||||
FirstTokenMs: firstTokenMs,
|
Duration: time.Since(startTime),
|
||||||
Usage: ClaudeUsage{},
|
FirstTokenMs: firstTokenMs,
|
||||||
MediaType: mediaType,
|
Usage: ClaudeUsage{},
|
||||||
MediaURL: firstMediaURL(finalURLs),
|
MediaType: mediaType,
|
||||||
ImageCount: imageCount,
|
MediaURL: firstMediaURL(finalURLs),
|
||||||
ImageSize: imageSize,
|
ImageCount: imageCount,
|
||||||
|
ImageSize: imageSize,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
|||||||
ID: 1,
|
ID: 1,
|
||||||
Platform: PlatformSora,
|
Platform: PlatformSora,
|
||||||
Status: StatusActive,
|
Status: StatusActive,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
|
||||||
|
|
||||||
@@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.Equal(t, "prompt", result.MediaType)
|
require.Equal(t, "prompt", result.MediaType)
|
||||||
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
require.Equal(t, "prompt-enhance-short-10s", result.Model)
|
||||||
|
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user