merge upstream main
This commit is contained in:
@@ -26,7 +26,7 @@ import (
|
||||
|
||||
const (
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityDefaultMaxRetries = 5
|
||||
antigravityDefaultMaxRetries = 3
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
@@ -52,11 +52,11 @@ type antigravityRetryLoopParams struct {
|
||||
action string
|
||||
body []byte
|
||||
quotaScope AntigravityQuotaScope
|
||||
maxRetries int
|
||||
c *gin.Context
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
|
||||
maxRetries int // 可选,0 表示使用平台级默认值
|
||||
}
|
||||
|
||||
// antigravityRetryLoopResult 重试循环的结果
|
||||
@@ -82,9 +82,10 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = baseURLs
|
||||
}
|
||||
|
||||
maxRetries := p.maxRetries
|
||||
if maxRetries <= 0 {
|
||||
maxRetries = antigravityMaxRetries()
|
||||
maxRetries = antigravityDefaultMaxRetries
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
@@ -161,7 +162,7 @@ urlFallbackLoop:
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
|
||||
// 账户/模型配额限流,按最大重试次数做指数退避
|
||||
// 账户/模型配额限流,重试 3 次(指数退避)
|
||||
if attempt < maxRetries {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
@@ -1044,7 +1045,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: billingModel,
|
||||
Model: billingModel, // 计费模型(可按映射模型覆盖)
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1729,7 +1730,6 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||||
}
|
||||
return time.Duration(seconds) * time.Second, true
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
@@ -1742,9 +1742,6 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
||||
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
|
||||
}
|
||||
defaultDur := time.Duration(fallbackMinutes) * time.Minute
|
||||
if override, ok := antigravityFallbackCooldownSeconds(); ok {
|
||||
defaultDur = override
|
||||
}
|
||||
ra := time.Now().Add(defaultDur)
|
||||
if useScopeLimit {
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||
@@ -2185,6 +2182,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
|
||||
return result, existingParts, setParts
|
||||
}
|
||||
|
||||
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||||
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||||
// 保持原始顺序,只合并连续的普通 text parts
|
||||
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||||
if len(collectedParts) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
result, _, setParts := getOrCreateGeminiParts(response)
|
||||
|
||||
// 合并策略:
|
||||
// 1. 保持原始顺序
|
||||
// 2. 连续的普通 text parts 合并为一个
|
||||
// 3. thinking、functionCall、inlineData 等保持原样
|
||||
var mergedParts []any
|
||||
var textBuffer strings.Builder
|
||||
|
||||
flushTextBuffer := func() {
|
||||
if textBuffer.Len() > 0 {
|
||||
mergedParts = append(mergedParts, map[string]any{
|
||||
"text": textBuffer.String(),
|
||||
})
|
||||
textBuffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range collectedParts {
|
||||
// 检查是否是普通 text part
|
||||
if text, ok := part["text"].(string); ok {
|
||||
// 检查是否有 thought 标记
|
||||
if thought, _ := part["thought"].(bool); thought {
|
||||
// thinking part,先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
} else {
|
||||
// 普通 text,累积到 buffer
|
||||
_, _ = textBuffer.WriteString(text)
|
||||
}
|
||||
} else {
|
||||
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新剩余的 text
|
||||
flushTextBuffer()
|
||||
|
||||
setParts(mergedParts)
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||||
if len(imageParts) == 0 {
|
||||
@@ -2372,8 +2421,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
var firstTokenMs *int
|
||||
var last map[string]any
|
||||
var lastWithParts map[string]any
|
||||
var collectedImageParts []map[string]any // 收集所有包含图片的 parts
|
||||
var collectedTextParts []string // 收集所有文本片段
|
||||
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -2468,18 +2516,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
last = parsed
|
||||
|
||||
// 保留最后一个有 parts 的响应
|
||||
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
// 收集包含图片和文本的 parts
|
||||
for _, part := range parts {
|
||||
if _, ok := part["inlineData"].(map[string]any); ok {
|
||||
collectedImageParts = append(collectedImageParts, part)
|
||||
}
|
||||
if text, ok := part["text"].(string); ok && text != "" {
|
||||
collectedTextParts = append(collectedTextParts, text)
|
||||
}
|
||||
}
|
||||
|
||||
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||||
collectedParts = append(collectedParts, parts...)
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -2502,32 +2544,13 @@ returnResponse:
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||
}
|
||||
|
||||
// 如果收集到了图片 parts,需要合并到最终响应中
|
||||
if len(collectedImageParts) > 0 {
|
||||
finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts)
|
||||
}
|
||||
|
||||
// 如果收集到了文本,需要合并到最终响应中
|
||||
if len(collectedTextParts) > 0 {
|
||||
finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts)
|
||||
}
|
||||
|
||||
geminiPayload := finalResponse
|
||||
if _, ok := finalResponse["response"]; !ok {
|
||||
wrapped := map[string]any{
|
||||
"response": finalResponse,
|
||||
}
|
||||
if respID, ok := finalResponse["responseId"]; ok {
|
||||
wrapped["responseId"] = respID
|
||||
}
|
||||
if modelVersion, ok := finalResponse["modelVersion"]; ok {
|
||||
wrapped["modelVersion"] = modelVersion
|
||||
}
|
||||
geminiPayload = wrapped
|
||||
// 将收集的所有 parts 合并到最终响应中
|
||||
if len(collectedParts) > 0 {
|
||||
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||||
}
|
||||
|
||||
// 序列化为 JSON(Gemini 格式)
|
||||
geminiBody, err := json.Marshal(geminiPayload)
|
||||
geminiBody, err := json.Marshal(finalResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal gemini response: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user