first commit
This commit is contained in:
228
backend/internal/pkg/antigravity/claude_types.go
Normal file
228
backend/internal/pkg/antigravity/claude_types.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package antigravity
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Claude 请求/响应类型定义
|
||||
|
||||
// ClaudeRequest Claude Messages API 请求
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ClaudeMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Tools []ClaudeTool `json:"tools,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeMessage Claude 消息
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"` // user, assistant
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
// ThinkingConfig Thinking 配置
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"` // "enabled" or "disabled"
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
|
||||
}
|
||||
|
||||
// ClaudeMetadata 请求元数据
|
||||
type ClaudeMetadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeTool Claude 工具定义
|
||||
// 支持两种格式:
|
||||
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
|
||||
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
|
||||
type ClaudeTool struct {
|
||||
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"` // 标准格式使用
|
||||
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
|
||||
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
|
||||
}
|
||||
|
||||
// CustomToolSpec MCP custom 工具规格
|
||||
type CustomToolSpec struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
|
||||
type ClaudeCustomToolSpec = CustomToolSpec
|
||||
|
||||
// SystemBlock system prompt 数组形式的元素
|
||||
type SystemBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ContentBlock Claude 消息内容块(解析后)
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
// tool_result
|
||||
ToolUseID string `json:"tool_use_id,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
// image
|
||||
Source *ImageSource `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// ImageSource Claude 图片来源
|
||||
type ImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// ClaudeResponse Claude Messages API 响应
|
||||
type ClaudeResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "message"
|
||||
Role string `json:"role"` // "assistant"
|
||||
Model string `json:"model"`
|
||||
Content []ClaudeContentItem `json:"content"`
|
||||
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
|
||||
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// ClaudeContentItem Claude 响应内容项
|
||||
type ClaudeContentItem struct {
|
||||
Type string `json:"type"` // text, thinking, tool_use
|
||||
|
||||
// text
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsage Claude 用量统计
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeError Claude 错误响应
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// modelDef Antigravity 模型定义(内部使用)
|
||||
type modelDef struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
CreatedAt string // 仅 Claude API 格式使用
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Claude 模型
|
||||
var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
// ========== Claude API 格式 (/v1/models) ==========
|
||||
|
||||
// ClaudeModel Claude API 模型格式
|
||||
type ClaudeModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini)
|
||||
func DefaultModels() []ClaudeModel {
|
||||
all := append(claudeModels, geminiModels...)
|
||||
result := make([]ClaudeModel, len(all))
|
||||
for i, m := range all {
|
||||
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
|
||||
|
||||
// GeminiModel Gemini v1beta 模型格式
|
||||
type GeminiModel struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModelsListResponse Gemini v1beta 模型列表响应
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModel `json:"models"`
|
||||
}
|
||||
|
||||
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
|
||||
|
||||
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
|
||||
func DefaultGeminiModels() []GeminiModel {
|
||||
result := make([]GeminiModel, len(geminiModels))
|
||||
for i, m := range geminiModels {
|
||||
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
|
||||
func FallbackGeminiModelsList() GeminiModelsListResponse {
|
||||
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
|
||||
}
|
||||
|
||||
// FallbackGeminiModel 返回单个模型信息(v1beta 格式)
|
||||
func FallbackGeminiModel(model string) GeminiModel {
|
||||
if model == "" {
|
||||
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
name := model
|
||||
if len(model) < 7 || model[:7] != "models/" {
|
||||
name = "models/" + model
|
||||
}
|
||||
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
474
backend/internal/pkg/antigravity/client.go
Normal file
474
backend/internal/pkg/antigravity/client.go
Normal file
@@ -0,0 +1,474 @@
|
||||
// Package antigravity provides a client for the Antigravity API.
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// resolveHost 从 URL 解析 host
|
||||
func resolveHost(urlStr string) string {
|
||||
parsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Host
|
||||
}
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
// 构建 URL,流式请求添加 ?alt=sse 参数
|
||||
apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action)
|
||||
isStream := action == "streamGenerateContent"
|
||||
if isStream {
|
||||
apiURL += "?alt=sse"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 基础 Headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
// Accept Header 根据请求类型设置
|
||||
if isStream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 显式设置 Host Header
|
||||
if host := resolveHost(apiURL); host != "" {
|
||||
req.Host = host
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
// 向后兼容:仅使用默认 BaseURL
|
||||
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
|
||||
return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body)
|
||||
}
|
||||
|
||||
// TokenResponse Google OAuth token 响应
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo Google 用户信息
|
||||
type UserInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistRequest loadCodeAssist 请求
|
||||
type LoadCodeAssistRequest struct {
|
||||
Metadata struct {
|
||||
IDEType string `json:"ideType"`
|
||||
} `json:"metadata"`
|
||||
}
|
||||
|
||||
// TierInfo 账户类型信息
|
||||
type TierInfo struct {
|
||||
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
|
||||
Name string `json:"name"` // 显示名称
|
||||
Description string `json:"description"` // 描述
|
||||
}
|
||||
|
||||
// UnmarshalJSON supports both legacy string tiers and object tiers.
|
||||
func (t *TierInfo) UnmarshalJSON(data []byte) error {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var id string
|
||||
if err := json.Unmarshal(data, &id); err != nil {
|
||||
return err
|
||||
}
|
||||
t.ID = id
|
||||
return nil
|
||||
}
|
||||
type alias TierInfo
|
||||
var decoded alias
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = TierInfo(decoded)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IneligibleTier 不符合条件的层级信息
|
||||
type IneligibleTier struct {
|
||||
Tier *TierInfo `json:"tier,omitempty"`
|
||||
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
|
||||
ReasonCode string `json:"reasonCode,omitempty"`
|
||||
ReasonMessage string `json:"reasonMessage,omitempty"`
|
||||
}
|
||||
|
||||
// LoadCodeAssistResponse loadCodeAssist 响应
|
||||
type LoadCodeAssistResponse struct {
|
||||
CloudAICompanionProject string `json:"cloudaicompanionProject"`
|
||||
CurrentTier *TierInfo `json:"currentTier,omitempty"`
|
||||
PaidTier *TierInfo `json:"paidTier,omitempty"`
|
||||
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
|
||||
}
|
||||
|
||||
// GetTier 获取账户类型
|
||||
// 优先返回 paidTier(付费订阅级别),否则返回 currentTier
|
||||
func (r *LoadCodeAssistResponse) GetTier() string {
|
||||
if r.PaidTier != nil && r.PaidTier.ID != "" {
|
||||
return r.PaidTier.ID
|
||||
}
|
||||
if r.CurrentTier != nil {
|
||||
return r.CurrentTier.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Client Antigravity API 客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(proxyURL string) *Client {
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
func isConnectionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查超时错误
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查连接错误(DNS 失败、连接拒绝)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 URL 错误
|
||||
var urlErr *url.Error
|
||||
return errors.As(err, &urlErr)
|
||||
}
|
||||
|
||||
// shouldFallbackToNextURL 判断是否应切换到下一个 URL
|
||||
// 仅连接错误和 HTTP 429 触发 URL 降级
|
||||
func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
if isConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
return statusCode == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
params.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// GetUserInfo 获取用户信息
|
||||
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("用户信息请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("用户信息解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
|
||||
reqBody := LoadCodeAssistRequest{}
|
||||
reqBody.Metadata.IDEType = "ANTIGRAVITY"
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:loadCodeAssist"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var loadResp LoadCodeAssistResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &loadResp, rawResp, nil
|
||||
}
|
||||
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
// ModelQuotaInfo 模型配额信息
|
||||
type ModelQuotaInfo struct {
|
||||
RemainingFraction float64 `json:"remainingFraction"`
|
||||
ResetTime string `json:"resetTime,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsRequest fetchAvailableModels 请求
|
||||
type FetchAvailableModelsRequest struct {
|
||||
Project string `json:"project"`
|
||||
}
|
||||
|
||||
// FetchAvailableModelsResponse fetchAvailableModels 响应
|
||||
type FetchAvailableModelsResponse struct {
|
||||
Models map[string]ModelInfo `json:"models"`
|
||||
}
|
||||
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取可用的 URL 列表
|
||||
availableURLs := DefaultURLAvailability.GetAvailableURLs()
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("创建请求失败: %w", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
DefaultURLAvailability.MarkUnavailable(baseURL)
|
||||
log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp FetchAvailableModelsResponse
|
||||
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析原始 JSON 为 map
|
||||
var rawResp map[string]any
|
||||
_ = json.Unmarshal(respBodyBytes, &rawResp)
|
||||
|
||||
return &modelsResp, rawResp, nil
|
||||
}
|
||||
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
175
backend/internal/pkg/antigravity/gemini_types.go
Normal file
175
backend/internal/pkg/antigravity/gemini_types.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package antigravity
|
||||
|
||||
// Gemini v1internal 请求/响应类型定义
|
||||
|
||||
// V1InternalRequest v1internal 请求包装
|
||||
type V1InternalRequest struct {
|
||||
Project string `json:"project"`
|
||||
RequestID string `json:"requestId"`
|
||||
UserAgent string `json:"userAgent"`
|
||||
RequestType string `json:"requestType,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Request GeminiRequest `json:"request"`
|
||||
}
|
||||
|
||||
// GeminiRequest Gemini 请求内容
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
|
||||
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
|
||||
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiContent Gemini 内容
|
||||
type GeminiContent struct {
|
||||
Role string `json:"role"` // user, model
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
// GeminiPart Gemini 内容部分
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiInlineData Gemini 内联数据(图片等)
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCall Gemini 函数调用
|
||||
type GeminiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args any `json:"args,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionResponse Gemini 函数响应
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]any `json:"response"`
|
||||
ID string `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGenerationConfig Gemini 生成配置
|
||||
type GeminiGenerationConfig struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
}
|
||||
|
||||
// GeminiThinkingConfig Gemini thinking 配置
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts"`
|
||||
ThinkingBudget int `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolDeclaration Gemini 工具声明
|
||||
type GeminiToolDeclaration struct {
|
||||
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
|
||||
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionDecl Gemini 函数声明
|
||||
type GeminiFunctionDecl struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiGoogleSearch Gemini Google 搜索工具
|
||||
type GeminiGoogleSearch struct {
|
||||
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiEnhancedContent 增强内容配置
|
||||
type GeminiEnhancedContent struct {
|
||||
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageSearch 图片搜索配置
|
||||
type GeminiImageSearch struct {
|
||||
MaxResultCount int `json:"maxResultCount,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiToolConfig Gemini 工具配置
|
||||
type GeminiToolConfig struct {
|
||||
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFunctionCallingConfig 函数调用配置
|
||||
type GeminiFunctionCallingConfig struct {
|
||||
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
|
||||
}
|
||||
|
||||
// GeminiSafetySetting Gemini 安全设置
|
||||
type GeminiSafetySetting struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
// V1InternalResponse v1internal 响应包装
|
||||
type V1InternalResponse struct {
|
||||
Response GeminiResponse `json:"response"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiResponse Gemini 响应
|
||||
type GeminiResponse struct {
|
||||
Candidates []GeminiCandidate `json:"candidates,omitempty"`
|
||||
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
|
||||
ResponseID string `json:"responseId,omitempty"`
|
||||
ModelVersion string `json:"modelVersion,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiCandidate Gemini 候选响应
|
||||
type GeminiCandidate struct {
|
||||
Content *GeminiContent `json:"content,omitempty"`
|
||||
FinishReason string `json:"finishReason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
var DefaultSafetySettings = []GeminiSafetySetting{
|
||||
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
|
||||
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
|
||||
}
|
||||
|
||||
// DefaultStopSequences 默认停止序列
|
||||
var DefaultStopSequences = []string{
|
||||
"<|user|>",
|
||||
"<|endoftext|>",
|
||||
"<|end_of_turn|>",
|
||||
"[DONE]",
|
||||
"\n\nHuman:",
|
||||
}
|
||||
263
backend/internal/pkg/antigravity/oauth.go
Normal file
263
backend/internal/pkg/antigravity/oauth.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Google OAuth 端点
|
||||
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TokenURL = "https://oauth2.googleapis.com/token"
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
|
||||
// OAuth scopes
|
||||
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
|
||||
"https://www.googleapis.com/auth/userinfo.email " +
|
||||
"https://www.googleapis.com/auth/userinfo.profile " +
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(模拟官方客户端)
|
||||
UserAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
// URL 可用性 TTL(不可用 URL 的恢复时间)
|
||||
URLAvailabilityTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点,按优先级排序
|
||||
// fallback 顺序: sandbox → daily → prod
|
||||
var BaseURLs = []string{
|
||||
"https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox
|
||||
"https://daily-cloudcode-pa.googleapis.com", // daily
|
||||
"https://cloudcode-pa.googleapis.com", // prod
|
||||
}
|
||||
|
||||
// BaseURL 默认 URL(保持向后兼容)
|
||||
var BaseURL = BaseURLs[0]
|
||||
|
||||
// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复)
|
||||
type URLAvailability struct {
|
||||
mu sync.RWMutex
|
||||
unavailable map[string]time.Time // URL -> 恢复时间
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// DefaultURLAvailability 全局 URL 可用性管理器
|
||||
var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL)
|
||||
|
||||
// NewURLAvailability 创建 URL 可用性管理器
|
||||
func NewURLAvailability(ttl time.Duration) *URLAvailability {
|
||||
return &URLAvailability{
|
||||
unavailable: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkUnavailable 标记 URL 临时不可用
|
||||
func (u *URLAvailability) MarkUnavailable(url string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.unavailable[url] = time.Now().Add(u.ttl)
|
||||
}
|
||||
|
||||
// IsAvailable 检查 URL 是否可用
|
||||
func (u *URLAvailability) IsAvailable(url string) bool {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
return time.Now().After(expiry)
|
||||
}
|
||||
|
||||
// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序)
|
||||
func (u *URLAvailability) GetAvailableURLs() []string {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
result := make([]string, 0, len(BaseURLs))
|
||||
for _, url := range BaseURLs {
|
||||
expiry, exists := u.unavailable[url]
|
||||
if !exists || now.After(expiry) {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// OAuthSession 保存 OAuth 授权流程的临时状态
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore OAuth session 存储
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func NewSessionStore() *SessionStore {
|
||||
store := &SessionStore{
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
go store.cleanup()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[sessionID] = session
|
||||
}
|
||||
|
||||
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
session, ok := s.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
}
|
||||
|
||||
func (s *SessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, sessionID)
|
||||
}
|
||||
|
||||
func (s *SessionStore) Stop() {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
close(s.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionStore) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.mu.Lock()
|
||||
for id, session := range s.sessions {
|
||||
if time.Since(session.CreatedAt) > SessionTTL {
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func GenerateState() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeVerifier() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64URLEncode(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateCodeChallenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
func base64URLEncode(data []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
|
||||
}
|
||||
|
||||
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
|
||||
func BuildAuthorizationURL(state, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", Scopes)
|
||||
params.Set("state", state)
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("access_type", "offline")
|
||||
params.Set("prompt", "consent")
|
||||
params.Set("include_granted_scopes", "true")
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用)
|
||||
// 格式:{形容词}-{名词}-{5位随机字符}
|
||||
func GenerateMockProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
randBytes, _ := GenerateRandomBytes(7)
|
||||
|
||||
adj := adjectives[int(randBytes[0])%len(adjectives)]
|
||||
noun := nouns[int(randBytes[1])%len(nouns)]
|
||||
|
||||
// 生成 5 位随机字符(a-z0-9)
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
suffix := make([]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
|
||||
}
|
||||
773
backend/internal/pkg/antigravity/request_transformer.go
Normal file
773
backend/internal/pkg/antigravity/request_transformer.go
Normal file
@@ -0,0 +1,773 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
sessionRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
sessionRandMutex sync.Mutex
|
||||
)
|
||||
|
||||
// generateStableSessionID 基于用户消息内容生成稳定的 session ID
|
||||
func generateStableSessionID(contents []GeminiContent) string {
|
||||
// 查找第一个 user 消息的文本
|
||||
for _, content := range contents {
|
||||
if content.Role == "user" && len(content.Parts) > 0 {
|
||||
if text := content.Parts[0].Text; text != "" {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 回退:生成随机 session ID
|
||||
sessionRandMutex.Lock()
|
||||
n := sessionRand.Int63n(9_000_000_000_000_000_000)
|
||||
sessionRandMutex.Unlock()
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
type TransformOptions struct {
|
||||
EnableIdentityPatch bool
|
||||
// IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词;
|
||||
// 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。
|
||||
IdentityPatch string
|
||||
}
|
||||
|
||||
func DefaultTransformOptions() TransformOptions {
|
||||
return TransformOptions{
|
||||
EnableIdentityPatch: true,
|
||||
}
|
||||
}
|
||||
|
||||
// TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式
|
||||
func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) {
|
||||
return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions())
|
||||
}
|
||||
|
||||
// TransformClaudeToGeminiWithOptions 将 Claude 请求转换为 v1internal Gemini 格式(可配置身份补丁等行为)
|
||||
func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, mappedModel string, opts TransformOptions) ([]byte, error) {
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build contents: %w", err)
|
||||
}
|
||||
|
||||
// 2. 构建 systemInstruction
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
reqForConfig := claudeReq
|
||||
if strippedThinking {
|
||||
// If we had to downgrade thinking blocks to plain text due to missing/invalid signatures,
|
||||
// disable upstream thinking mode to avoid signature/structure validation errors.
|
||||
reqCopy := *claudeReq
|
||||
reqCopy.Thinking = nil
|
||||
reqForConfig = &reqCopy
|
||||
}
|
||||
generationConfig := buildGenerationConfig(reqForConfig)
|
||||
|
||||
// 4. 构建 tools
|
||||
tools := buildTools(claudeReq.Tools)
|
||||
|
||||
// 5. 构建内部请求
|
||||
innerRequest := GeminiRequest{
|
||||
Contents: contents,
|
||||
// 总是设置 toolConfig,与官方客户端一致
|
||||
ToolConfig: &GeminiToolConfig{
|
||||
FunctionCallingConfig: &GeminiFunctionCallingConfig{
|
||||
Mode: "VALIDATED",
|
||||
},
|
||||
},
|
||||
// 总是生成 sessionId,基于用户消息内容
|
||||
SessionID: generateStableSessionID(contents),
|
||||
}
|
||||
|
||||
if systemInstruction != nil {
|
||||
innerRequest.SystemInstruction = systemInstruction
|
||||
}
|
||||
if generationConfig != nil {
|
||||
innerRequest.GenerationConfig = generationConfig
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
innerRequest.Tools = tools
|
||||
}
|
||||
|
||||
// 如果提供了 metadata.user_id,优先使用
|
||||
if claudeReq.Metadata != nil && claudeReq.Metadata.UserID != "" {
|
||||
innerRequest.SessionID = claudeReq.Metadata.UserID
|
||||
}
|
||||
|
||||
// 6. 包装为 v1internal 请求
|
||||
v1Req := V1InternalRequest{
|
||||
Project: projectID,
|
||||
RequestID: "agent-" + uuid.New().String(),
|
||||
UserAgent: "antigravity", // 固定值,与官方客户端一致
|
||||
RequestType: "agent",
|
||||
Model: mappedModel,
|
||||
Request: innerRequest,
|
||||
}
|
||||
|
||||
return json.Marshal(v1Req)
|
||||
}
|
||||
|
||||
// antigravityIdentity Antigravity identity 提示词
|
||||
const antigravityIdentity = `<identity>
|
||||
You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.
|
||||
You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.
|
||||
The USER will send you requests, which you must always prioritize addressing. Along with each USER request, we will attach additional metadata about their current state, such as what files they have open and where their cursor is.
|
||||
This information may or may not be relevant to the coding task, it is up for you to decide.
|
||||
</identity>
|
||||
<communication_style>
|
||||
- **Proactiveness**. As an agent, you are allowed to be proactive, but only in the course of completing the user's task. For example, if the user asks you to add a new component, you can edit the code, verify build and test statuses, and take any other obvious follow-up actions, such as performing additional research. However, avoid surprising the user. For example, if the user asks HOW to approach something, you should answer their question and instead of jumping into editing a file.</communication_style>`
|
||||
|
||||
func defaultIdentityPatch(_ string) string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// GetDefaultIdentityPatch 返回默认的 Antigravity 身份提示词
|
||||
func GetDefaultIdentityPatch() string {
|
||||
return antigravityIdentity
|
||||
}
|
||||
|
||||
// buildSystemInstruction 构建 systemInstruction
|
||||
func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent {
|
||||
var parts []GeminiPart
|
||||
|
||||
// 先解析用户的 system prompt,检测是否已包含 Antigravity identity
|
||||
userHasAntigravityIdentity := false
|
||||
var userSystemParts []GeminiPart
|
||||
|
||||
if len(system) > 0 {
|
||||
// 尝试解析为字符串
|
||||
var sysStr string
|
||||
if err := json.Unmarshal(system, &sysStr); err == nil {
|
||||
if strings.TrimSpace(sysStr) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr})
|
||||
if strings.Contains(sysStr, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 尝试解析为数组
|
||||
var sysBlocks []SystemBlock
|
||||
if err := json.Unmarshal(system, &sysBlocks); err == nil {
|
||||
for _, block := range sysBlocks {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text})
|
||||
if strings.Contains(block.Text, "You are Antigravity") {
|
||||
userHasAntigravityIdentity = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 仅在用户未提供 Antigravity identity 时注入
|
||||
if opts.EnableIdentityPatch && !userHasAntigravityIdentity {
|
||||
identityPatch := strings.TrimSpace(opts.IdentityPatch)
|
||||
if identityPatch == "" {
|
||||
identityPatch = defaultIdentityPatch(modelName)
|
||||
}
|
||||
parts = append(parts, GeminiPart{Text: identityPatch})
|
||||
}
|
||||
|
||||
// 添加用户的 system prompt
|
||||
parts = append(parts, userSystemParts...)
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &GeminiContent{
|
||||
Role: "user",
|
||||
Parts: parts,
|
||||
}
|
||||
}
|
||||
|
||||
// buildContents 构建 contents
|
||||
func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isThinkingEnabled, allowDummyThought bool) ([]GeminiContent, bool, error) {
|
||||
var contents []GeminiContent
|
||||
strippedThinking := false
|
||||
|
||||
for i, msg := range messages {
|
||||
role := msg.Role
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
parts, strippedThisMsg, err := buildParts(msg.Content, toolIDToName, allowDummyThought)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("build parts for message %d: %w", i, err)
|
||||
}
|
||||
if strippedThisMsg {
|
||||
strippedThinking = true
|
||||
}
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thinking block workaround
|
||||
// 只对最后一条 assistant 消息添加(Pre-fill 场景)
|
||||
// 历史 assistant 消息不能添加没有 signature 的 dummy thinking block
|
||||
if allowDummyThought && role == "model" && isThinkingEnabled && i == len(messages)-1 {
|
||||
hasThoughtPart := false
|
||||
for _, p := range parts {
|
||||
if p.Thought {
|
||||
hasThoughtPart = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasThoughtPart && len(parts) > 0 {
|
||||
// 在开头添加 dummy thinking block
|
||||
parts = append([]GeminiPart{{
|
||||
Text: "Thinking...",
|
||||
Thought: true,
|
||||
ThoughtSignature: dummyThoughtSignature,
|
||||
}}, parts...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
return contents, strippedThinking, nil
|
||||
}
|
||||
|
||||
// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证
|
||||
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
// buildParts 构建消息的 parts
|
||||
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
||||
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, bool, error) {
|
||||
var parts []GeminiPart
|
||||
strippedThinking := false
|
||||
|
||||
// 尝试解析为字符串
|
||||
var textContent string
|
||||
if err := json.Unmarshal(content, &textContent); err == nil {
|
||||
if textContent != "(no content)" && strings.TrimSpace(textContent) != "" {
|
||||
parts = append(parts, GeminiPart{Text: strings.TrimSpace(textContent)})
|
||||
}
|
||||
return parts, false, nil
|
||||
}
|
||||
|
||||
// 解析为内容块数组
|
||||
var blocks []ContentBlock
|
||||
if err := json.Unmarshal(content, &blocks); err != nil {
|
||||
return nil, false, fmt.Errorf("parse content blocks: %w", err)
|
||||
}
|
||||
|
||||
for _, block := range blocks {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
if block.Text != "(no content)" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, GeminiPart{Text: block.Text})
|
||||
}
|
||||
|
||||
case "thinking":
|
||||
part := GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
|
||||
if strings.TrimSpace(block.Thinking) != "" {
|
||||
parts = append(parts, GeminiPart{Text: block.Thinking})
|
||||
}
|
||||
strippedThinking = true
|
||||
continue
|
||||
} else {
|
||||
// Gemini 模型使用 dummy signature
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
case "image":
|
||||
if block.Source != nil && block.Source.Type == "base64" {
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: block.Source.MediaType,
|
||||
Data: block.Source.Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
// 存储 id -> name 映射
|
||||
if block.ID != "" && block.Name != "" {
|
||||
toolIDToName[block.ID] = block.Name
|
||||
}
|
||||
|
||||
part := GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: block.Name,
|
||||
Args: block.Input,
|
||||
ID: block.ID,
|
||||
},
|
||||
}
|
||||
// tool_use 的 signature 处理:
|
||||
// - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
|
||||
// - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
} else if block.Signature != "" && block.Signature != dummyThoughtSignature {
|
||||
part.ThoughtSignature = block.Signature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
|
||||
case "tool_result":
|
||||
// 获取函数名
|
||||
funcName := block.Name
|
||||
if funcName == "" {
|
||||
if name, ok := toolIDToName[block.ToolUseID]; ok {
|
||||
funcName = name
|
||||
} else {
|
||||
funcName = block.ToolUseID
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 content
|
||||
resultContent := parseToolResultContent(block.Content, block.IsError)
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: funcName,
|
||||
Response: map[string]any{
|
||||
"result": resultContent,
|
||||
},
|
||||
ID: block.ToolUseID,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return parts, strippedThinking, nil
|
||||
}
|
||||
|
||||
// parseToolResultContent 解析 tool_result 的 content
|
||||
func parseToolResultContent(content json.RawMessage, isError bool) string {
|
||||
if len(content) == 0 {
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(content, &str); err == nil {
|
||||
if strings.TrimSpace(str) == "" {
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal(content, &arr); err == nil {
|
||||
var texts []string
|
||||
for _, item := range arr {
|
||||
if text, ok := item["text"].(string); ok {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
result := strings.Join(texts, "\n")
|
||||
if strings.TrimSpace(result) == "" {
|
||||
if isError {
|
||||
return "Tool execution failed with no output."
|
||||
}
|
||||
return "Command executed successfully."
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// 返回原始 JSON
|
||||
return string(content)
|
||||
}
|
||||
|
||||
// buildGenerationConfig 构建 generationConfig
|
||||
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
|
||||
config := &GeminiGenerationConfig{
|
||||
MaxOutputTokens: 64000, // 默认最大输出
|
||||
StopSequences: DefaultStopSequences,
|
||||
}
|
||||
|
||||
// Thinking 配置
|
||||
if req.Thinking != nil && req.Thinking.Type == "enabled" {
|
||||
config.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if req.Thinking.BudgetTokens > 0 {
|
||||
budget := req.Thinking.BudgetTokens
|
||||
// gemini-2.5-flash 上限 24576
|
||||
if strings.Contains(req.Model, "gemini-2.5-flash") && budget > 24576 {
|
||||
budget = 24576
|
||||
}
|
||||
config.ThinkingConfig.ThinkingBudget = budget
|
||||
}
|
||||
}
|
||||
|
||||
// 其他参数
|
||||
if req.Temperature != nil {
|
||||
config.Temperature = req.Temperature
|
||||
}
|
||||
if req.TopP != nil {
|
||||
config.TopP = req.TopP
|
||||
}
|
||||
if req.TopK != nil {
|
||||
config.TopK = req.TopK
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// buildTools 构建 tools
|
||||
func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否有 web_search 工具
|
||||
hasWebSearch := false
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "web_search" {
|
||||
hasWebSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWebSearch {
|
||||
// Web Search 工具映射
|
||||
return []GeminiToolDeclaration{{
|
||||
GoogleSearch: &GeminiGoogleSearch{
|
||||
EnhancedContent: &GeminiEnhancedContent{
|
||||
ImageSearch: &GeminiImageSearch{
|
||||
MaxResultCount: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}}
|
||||
}
|
||||
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for _, tool := range tools {
|
||||
// 跳过无效工具名称
|
||||
if strings.TrimSpace(tool.Name) == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
continue
|
||||
}
|
||||
|
||||
var description string
|
||||
var inputSchema map[string]any
|
||||
|
||||
// 检查是否为 custom 类型工具 (MCP)
|
||||
if tool.Type == "custom" {
|
||||
if tool.Custom == nil || tool.Custom.InputSchema == nil {
|
||||
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
|
||||
continue
|
||||
}
|
||||
description = tool.Custom.Description
|
||||
inputSchema = tool.Custom.InputSchema
|
||||
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
description = tool.Description
|
||||
inputSchema = tool.InputSchema
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
params := cleanJSONSchema(inputSchema)
|
||||
// 为 nil schema 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "OBJECT",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||
Name: tool.Name,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
})
|
||||
}
|
||||
|
||||
if len(funcDecls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []GeminiToolDeclaration{{
|
||||
FunctionDeclarations: funcDecls,
|
||||
}}
|
||||
}
|
||||
|
||||
// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
|
||||
// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
|
||||
func cleanJSONSchema(schema map[string]any) map[string]any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
cleaned := cleanSchemaValue(schema, "$")
|
||||
result, ok := cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保有 type 字段(默认 OBJECT)
|
||||
if _, hasType := result["type"]; !hasType {
|
||||
result["type"] = "OBJECT"
|
||||
}
|
||||
|
||||
// 确保有 properties 字段(默认空对象)
|
||||
if _, hasProps := result["properties"]; !hasProps {
|
||||
result["properties"] = make(map[string]any)
|
||||
}
|
||||
|
||||
// 验证 required 中的字段都存在于 properties 中
|
||||
if required, ok := result["required"].([]any); ok {
|
||||
if props, ok := result["properties"].(map[string]any); ok {
|
||||
validRequired := make([]any, 0, len(required))
|
||||
for _, r := range required {
|
||||
if reqName, ok := r.(string); ok {
|
||||
if _, exists := props[reqName]; exists {
|
||||
validRequired = append(validRequired, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(validRequired) > 0 {
|
||||
result["required"] = validRequired
|
||||
} else {
|
||||
delete(result, "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
var schemaValidationKeys = map[string]bool{
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
}
|
||||
|
||||
var warnedSchemaKeys sync.Map
|
||||
|
||||
func schemaCleaningWarningsEnabled() bool {
|
||||
// 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
|
||||
if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
|
||||
switch strings.ToLower(v) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
}
|
||||
}
|
||||
// 默认:非 release 模式下输出(debug/test)
|
||||
return gin.Mode() != gin.ReleaseMode
|
||||
}
|
||||
|
||||
func warnSchemaKeyRemovedOnce(key, path string) {
|
||||
if !schemaCleaningWarningsEnabled() {
|
||||
return
|
||||
}
|
||||
if !schemaValidationKeys[key] {
|
||||
return
|
||||
}
|
||||
if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
|
||||
return
|
||||
}
|
||||
log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
|
||||
}
|
||||
|
||||
// excludedSchemaKeys 不支持的 schema 字段
|
||||
// 基于 Claude API (Vertex AI) 的实际支持情况
|
||||
// 支持: type, description, enum, properties, required, additionalProperties, items
|
||||
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
||||
var excludedSchemaKeys = map[string]bool{
|
||||
// 元 schema 字段
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
|
||||
// 字符串验证(Gemini 不支持)
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
|
||||
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
|
||||
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
|
||||
// 组合 schema(Gemini 不支持)
|
||||
"oneOf": true,
|
||||
"anyOf": true,
|
||||
"allOf": true,
|
||||
"not": true,
|
||||
"if": true,
|
||||
"then": true,
|
||||
"else": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
|
||||
// 对象验证(仅保留 properties/required/additionalProperties)
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
|
||||
// 其他不支持的字段
|
||||
"default": true,
|
||||
"const": true,
|
||||
"examples": true,
|
||||
"deprecated": true,
|
||||
"readOnly": true,
|
||||
"writeOnly": true,
|
||||
"contentMediaType": true,
|
||||
"contentEncoding": true,
|
||||
|
||||
// Claude 特有字段
|
||||
"strict": true,
|
||||
}
|
||||
|
||||
// cleanSchemaValue 递归清理 schema 值
|
||||
func cleanSchemaValue(value any, path string) any {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
// 跳过不支持的字段
|
||||
if excludedSchemaKeys[k] {
|
||||
warnSchemaKeyRemovedOnce(k, path)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 type 字段
|
||||
if k == "type" {
|
||||
result[k] = cleanTypeValue(val)
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
||||
if k == "format" {
|
||||
if formatStr, ok := val.(string); ok {
|
||||
// Gemini 只支持 date-time, date, time
|
||||
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
||||
result[k] = val
|
||||
}
|
||||
// 其他 format 值直接跳过
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
||||
if k == "additionalProperties" {
|
||||
if boolVal, ok := val.(bool); ok {
|
||||
result[k] = boolVal
|
||||
} else {
|
||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||
result[k] = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 递归清理所有值
|
||||
result[k] = cleanSchemaValue(val, path+"."+k)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
cleaned := make([]any, 0, len(v))
|
||||
for i, item := range v {
|
||||
cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
|
||||
}
|
||||
return cleaned
|
||||
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// cleanTypeValue 处理 type 字段,转换为大写
|
||||
func cleanTypeValue(value any) any {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return strings.ToUpper(v)
|
||||
case []any:
|
||||
// 联合类型 ["string", "null"] -> 取第一个非 null 类型
|
||||
for _, t := range v {
|
||||
if ts, ok := t.(string); ok && ts != "null" {
|
||||
return strings.ToUpper(ts)
|
||||
}
|
||||
}
|
||||
// 如果只有 null,返回 STRING
|
||||
return "STRING"
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
244
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
244
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
allowDummyThought bool
|
||||
expectedParts int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Claude model - downgrade thinking to text without signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3, // thinking 内容降级为普通 text part
|
||||
description: "Claude模型缺少signature时应将thinking降级为text,并在上层禁用thinking mode",
|
||||
},
|
||||
{
|
||||
name: "Claude model - preserve thinking block with signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "sig_real_123"},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3,
|
||||
description: "Claude模型应透传带 signature 的 thinking block(用于 Vertex 签名链路)",
|
||||
},
|
||||
{
|
||||
name: "Gemini model - use dummy signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: true,
|
||||
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
|
||||
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
|
||||
if len(parts) != tt.expectedParts {
|
||||
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||
}
|
||||
|
||||
switch tt.name {
|
||||
case "Claude model - preserve thinking block with signature":
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
if !parts[1].Thought || parts[1].ThoughtSignature != "sig_real_123" {
|
||||
t.Fatalf("expected thought part with signature sig_real_123, got thought=%v signature=%q",
|
||||
parts[1].Thought, parts[1].ThoughtSignature)
|
||||
}
|
||||
case "Claude model - downgrade thinking to text without signature":
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
if parts[1].Thought {
|
||||
t.Fatalf("expected downgraded text part, got thought=%v signature=%q",
|
||||
parts[1].Thought, parts[1].ThoughtSignature)
|
||||
}
|
||||
if parts[1].Text != "Let me think..." {
|
||||
t.Fatalf("expected downgraded text %q, got %q", "Let me think...", parts[1].Text)
|
||||
}
|
||||
case "Gemini model - use dummy signature":
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy thought signature, got thought=%v signature=%q",
|
||||
parts[1].Thought, parts[1].ThoughtSignature)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
|
||||
content := `[
|
||||
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
|
||||
]`
|
||||
|
||||
t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
if parts[0].ThoughtSignature != dummyThoughtSignature {
|
||||
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Claude model - preserve valid signature for tool_use", func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, false)
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].FunctionCall == nil {
|
||||
t.Fatalf("expected 1 functionCall part, got %+v", parts)
|
||||
}
|
||||
// Claude 模型应透传有效的 signature(Vertex/Google 需要完整签名链路)
|
||||
if parts[0].ThoughtSignature != "sig_tool_abc" {
|
||||
t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []ClaudeTool
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tool format",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "mcp_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "MCP tool description",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"param": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从Custom字段读取description和input_schema",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "standard_tool",
|
||||
Description: "Standard tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "custom_tool",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Custom tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
|
||||
description: "混合标准和custom工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil Custom field",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
// Custom 为 nil
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "Custom字段为nil的custom工具应该被跳过",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil InputSchema",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Invalid",
|
||||
// InputSchema 为 nil
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "InputSchema为nil的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildTools(tt.tools)
|
||||
|
||||
if len(result) != tt.expectedLen {
|
||||
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
|
||||
}
|
||||
|
||||
// 验证function declarations存在
|
||||
if len(result) > 0 && result[0].FunctionDeclarations != nil {
|
||||
if len(result[0].FunctionDeclarations) != len(tt.tools) {
|
||||
t.Errorf("%s: got %d function declarations, want %d",
|
||||
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
273
backend/internal/pkg/antigravity/response_transformer.go
Normal file
273
backend/internal/pkg/antigravity/response_transformer.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
|
||||
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
// 使用处理器转换
|
||||
processor := NewNonStreamingProcessor()
|
||||
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
|
||||
|
||||
// 序列化
|
||||
respBytes, err := json.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
|
||||
}
|
||||
|
||||
return respBytes, &claudeResp.Usage, nil
|
||||
}
|
||||
|
||||
// NonStreamingProcessor 非流式响应处理器
|
||||
type NonStreamingProcessor struct {
|
||||
contentBlocks []ClaudeContentItem
|
||||
textBuilder string
|
||||
thinkingBuilder string
|
||||
thinkingSignature string
|
||||
trailingSignature string
|
||||
hasToolCall bool
|
||||
}
|
||||
|
||||
// NewNonStreamingProcessor 创建非流式响应处理器
|
||||
func NewNonStreamingProcessor() *NonStreamingProcessor {
|
||||
return &NonStreamingProcessor{
|
||||
contentBlocks: make([]ClaudeContentItem, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理 Gemini 响应
|
||||
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
// 获取 parts
|
||||
var parts []GeminiPart
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
parts = geminiResp.Candidates[0].Content.Parts
|
||||
}
|
||||
|
||||
// 处理所有 parts
|
||||
for _, part := range parts {
|
||||
p.processPart(&part)
|
||||
}
|
||||
|
||||
// 刷新剩余内容
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
return p.buildResponse(geminiResp, responseID, originalModel)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
p.flushThinking()
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.hasToolCall = true
|
||||
|
||||
// 生成 tool_use id
|
||||
toolID := part.FunctionCall.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
|
||||
}
|
||||
|
||||
item := ClaudeContentItem{
|
||||
Type: "tool_use",
|
||||
ID: toolID,
|
||||
Name: part.FunctionCall.Name,
|
||||
Input: part.FunctionCall.Args,
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
item.Signature = signature
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, item)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
// Thinking part
|
||||
p.flushText()
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushThinking()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.thinkingBuilder += part.Text
|
||||
if signature != "" {
|
||||
p.thinkingSignature = signature
|
||||
}
|
||||
} else {
|
||||
// 普通 Text
|
||||
if part.Text == "" {
|
||||
// 空 text 带签名 - 暂存
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p.flushThinking()
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: p.trailingSignature,
|
||||
})
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
p.textBuilder += part.Text
|
||||
|
||||
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
|
||||
if signature != "" {
|
||||
p.flushText()
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Signature: signature,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
p.flushThinking()
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
p.textBuilder += markdownImg
|
||||
p.flushText()
|
||||
}
|
||||
}
|
||||
|
||||
// flushText 刷新 text builder
|
||||
func (p *NonStreamingProcessor) flushText() {
|
||||
if p.textBuilder == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "text",
|
||||
Text: p.textBuilder,
|
||||
})
|
||||
p.textBuilder = ""
|
||||
}
|
||||
|
||||
// flushThinking 刷新 thinking builder
|
||||
func (p *NonStreamingProcessor) flushThinking() {
|
||||
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
|
||||
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
|
||||
Type: "thinking",
|
||||
Thinking: p.thinkingBuilder,
|
||||
Signature: p.thinkingSignature,
|
||||
})
|
||||
p.thinkingBuilder = ""
|
||||
p.thinkingSignature = ""
|
||||
}
|
||||
|
||||
// buildResponse 构建最终响应
|
||||
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
|
||||
var finishReason string
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason = geminiResp.Candidates[0].FinishReason
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
if p.hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
usage := ClaudeUsage{}
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
respID := responseID
|
||||
if respID == "" {
|
||||
respID = geminiResp.ResponseID
|
||||
}
|
||||
if respID == "" {
|
||||
respID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
return &ClaudeResponse{
|
||||
ID: respID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: originalModel,
|
||||
Content: p.contentBlocks,
|
||||
StopReason: stopReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
464
backend/internal/pkg/antigravity/stream_transformer.go
Normal file
464
backend/internal/pkg/antigravity/stream_transformer.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BlockType 内容块类型
|
||||
type BlockType int
|
||||
|
||||
const (
|
||||
BlockTypeNone BlockType = iota
|
||||
BlockTypeText
|
||||
BlockTypeThinking
|
||||
BlockTypeFunction
|
||||
)
|
||||
|
||||
// StreamingProcessor 流式响应处理器
|
||||
type StreamingProcessor struct {
|
||||
blockType BlockType
|
||||
blockIndex int
|
||||
messageStartSent bool
|
||||
messageStopSent bool
|
||||
usedTool bool
|
||||
pendingSignature string
|
||||
trailingSignature string
|
||||
originalModel string
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
||||
return &StreamingProcessor{
|
||||
blockType: BlockTypeNone,
|
||||
originalModel: originalModel,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if data == "" || data == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
var v1Resp V1InternalResponse
|
||||
if err := json.Unmarshal([]byte(data), &v1Resp); err != nil {
|
||||
// 尝试直接解析为 GeminiResponse
|
||||
var directResp GeminiResponse
|
||||
if err2 := json.Unmarshal([]byte(data), &directResp); err2 != nil {
|
||||
return nil
|
||||
}
|
||||
v1Resp.Response = directResp
|
||||
v1Resp.ResponseID = directResp.ResponseID
|
||||
v1Resp.ModelVersion = directResp.ModelVersion
|
||||
}
|
||||
|
||||
geminiResp := &v1Resp.Response
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// 发送 message_start
|
||||
if !p.messageStartSent {
|
||||
_, _ = result.Write(p.emitMessageStart(&v1Resp))
|
||||
}
|
||||
|
||||
// 更新 usage
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
_, _ = result.Write(p.processPart(&part))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否结束
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason != "" {
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// Finish 结束处理,返回最终事件和用量
|
||||
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
}
|
||||
|
||||
// emitMessageStart 发送 message_start 事件
|
||||
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
|
||||
if p.messageStartSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{}
|
||||
if v1Resp.Response.UsageMetadata != nil {
|
||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
if responseID == "" {
|
||||
responseID = v1Resp.Response.ResponseID
|
||||
}
|
||||
if responseID == "" {
|
||||
responseID = "msg_" + generateRandomID()
|
||||
}
|
||||
|
||||
message := map[string]any{
|
||||
"id": responseID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": p.originalModel,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
p.messageStartSent = true
|
||||
return p.formatSSE("message_start", event)
|
||||
}
|
||||
|
||||
// processPart 处理单个 part
|
||||
func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
var result bytes.Buffer
|
||||
signature := part.ThoughtSignature
|
||||
|
||||
// 1. FunctionCall 处理
|
||||
if part.FunctionCall != nil {
|
||||
// 先处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
_, _ = result.Write(p.processThinking(part.Text, signature))
|
||||
} else {
|
||||
_, _ = result.Write(p.processText(part.Text, signature))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. InlineData (Image) 处理
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
_, _ = result.Write(p.processText(markdownImg, ""))
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processThinking 处理 thinking
|
||||
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 开始或继续 thinking 块
|
||||
if p.blockType != BlockTypeThinking {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": text,
|
||||
}))
|
||||
}
|
||||
|
||||
// 暂存签名
|
||||
if signature != "" {
|
||||
p.pendingSignature = signature
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processText 处理普通 text
|
||||
func (p *StreamingProcessor) processText(text, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 空 text 带签名 - 暂存
|
||||
if text == "" {
|
||||
if signature != "" {
|
||||
p.trailingSignature = signature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 非空 text 带签名 - 特殊处理
|
||||
if signature != "" {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 普通 text (无签名)
|
||||
if p.blockType != BlockTypeText {
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
|
||||
"text": text,
|
||||
}))
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// processFunctionCall 处理 function call
|
||||
func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
p.usedTool = true
|
||||
|
||||
toolID := fc.ID
|
||||
if toolID == "" {
|
||||
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": toolID,
|
||||
"name": fc.Name,
|
||||
"input": map[string]any{},
|
||||
}
|
||||
|
||||
if signature != "" {
|
||||
toolUse["signature"] = signature
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
|
||||
|
||||
// 发送 input_json_delta
|
||||
if fc.Args != nil {
|
||||
argsJSON, _ := json.Marshal(fc.Args)
|
||||
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]any{
|
||||
"partial_json": string(argsJSON),
|
||||
}))
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// startBlock 开始新的内容块
|
||||
func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[string]any) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
if p.blockType != BlockTypeNone {
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": p.blockIndex,
|
||||
"content_block": contentBlock,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_start", event))
|
||||
p.blockType = blockType
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// endBlock 结束当前内容块
|
||||
func (p *StreamingProcessor) endBlock() []byte {
|
||||
if p.blockType == BlockTypeNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
|
||||
// Thinking 块结束时发送暂存的签名
|
||||
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": p.pendingSignature,
|
||||
}))
|
||||
p.pendingSignature = ""
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": p.blockIndex,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("content_block_stop", event))
|
||||
|
||||
p.blockIndex++
|
||||
p.blockType = BlockTypeNone
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitDelta 发送 delta 事件
|
||||
func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string]any) []byte {
|
||||
delta := map[string]any{
|
||||
"type": deltaType,
|
||||
}
|
||||
for k, v := range deltaContent {
|
||||
delta[k] = v
|
||||
}
|
||||
|
||||
event := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": p.blockIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
|
||||
return p.formatSSE("content_block_delta", event)
|
||||
}
|
||||
|
||||
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
|
||||
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
|
||||
"thinking": "",
|
||||
}))
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
|
||||
"signature": signature,
|
||||
}))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// emitFinish 发送结束事件
|
||||
func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 关闭最后一个块
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 确定 stop_reason
|
||||
stopReason := "end_turn"
|
||||
if p.usedTool {
|
||||
stopReason = "tool_use"
|
||||
} else if finishReason == "MAX_TOKENS" {
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
if !p.messageStopSent {
|
||||
stopEvent := map[string]any{
|
||||
"type": "message_stop",
|
||||
}
|
||||
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
|
||||
p.messageStopSent = true
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// formatSSE 格式化 SSE 事件
|
||||
func (p *StreamingProcessor) formatSSE(eventType string, data any) []byte {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(jsonData)))
|
||||
}
|
||||
Reference in New Issue
Block a user