merge: 合并 test 分支到 test-dev,解决冲突
解决的冲突文件: - wire_gen.go: 合并 ConcurrencyService/CRSSyncService 参数和 userAttributeHandler - gateway_handler.go: 合并 pkg/errors 和 antigravity 导入 - gateway_service.go: 合并 validateUpstreamBaseURL 和 GetAvailableModels - config.example.yaml: 合并 billing/turnstile 配置和额外 gateway 选项 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -37,12 +37,26 @@ type ClaudeMetadata struct {
|
||||
}
|
||||
|
||||
// ClaudeTool Claude 工具定义
|
||||
// 支持两种格式:
|
||||
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
|
||||
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
|
||||
type ClaudeTool struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"` // 标准格式使用
|
||||
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
|
||||
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
|
||||
}
|
||||
|
||||
// CustomToolSpec MCP custom 工具规格
|
||||
type CustomToolSpec struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"input_schema"`
|
||||
}
|
||||
|
||||
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
|
||||
type ClaudeCustomToolSpec = CustomToolSpec
|
||||
|
||||
// SystemBlock system prompt 数组形式的元素
|
||||
type SystemBlock struct {
|
||||
Type string `json:"type"`
|
||||
@@ -124,3 +138,91 @@ type ErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// modelDef Antigravity 模型定义(内部使用)
|
||||
type modelDef struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
CreatedAt string // 仅 Claude API 格式使用
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Claude 模型
|
||||
var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
var geminiModels = []modelDef{
|
||||
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
// ========== Claude API 格式 (/v1/models) ==========
|
||||
|
||||
// ClaudeModel Claude API 模型格式
|
||||
type ClaudeModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// DefaultModels 返回 Claude API 格式的模型列表(Claude + Gemini)
|
||||
func DefaultModels() []ClaudeModel {
|
||||
all := append(claudeModels, geminiModels...)
|
||||
result := make([]ClaudeModel, len(all))
|
||||
for i, m := range all {
|
||||
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
|
||||
|
||||
// GeminiModel Gemini v1beta 模型格式
|
||||
type GeminiModel struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModelsListResponse Gemini v1beta 模型列表响应
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModel `json:"models"`
|
||||
}
|
||||
|
||||
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
|
||||
|
||||
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
|
||||
func DefaultGeminiModels() []GeminiModel {
|
||||
result := make([]GeminiModel, len(geminiModels))
|
||||
for i, m := range geminiModels {
|
||||
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
|
||||
func FallbackGeminiModelsList() GeminiModelsListResponse {
|
||||
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
|
||||
}
|
||||
|
||||
// FallbackGeminiModel 返回单个模型信息(v1beta 格式)
|
||||
func FallbackGeminiModel(model string) GeminiModel {
|
||||
if model == "" {
|
||||
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
name := model
|
||||
if len(model) < 7 || model[:7] != "models/" {
|
||||
name = "models/" + model
|
||||
}
|
||||
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package antigravity
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -13,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||
|
||||
// 检测是否启用 thinking
|
||||
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
|
||||
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
|
||||
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
if err != nil {
|
||||
@@ -30,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
generationConfig := buildGenerationConfig(claudeReq)
|
||||
reqForGen := claudeReq
|
||||
if requestedThinkingEnabled && !allowDummyThought {
|
||||
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
|
||||
// shallow copy to avoid mutating caller's request
|
||||
clone := *claudeReq
|
||||
clone.Thinking = nil
|
||||
reqForGen = &clone
|
||||
}
|
||||
generationConfig := buildGenerationConfig(reqForGen)
|
||||
|
||||
// 4. 构建 tools
|
||||
tools := buildTools(claudeReq.Tools)
|
||||
@@ -147,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
||||
if !hasThoughtPart && len(parts) > 0 {
|
||||
// 在开头添加 dummy thinking block
|
||||
parts = append([]GeminiPart{{
|
||||
Text: "Thinking...",
|
||||
Thought: true,
|
||||
Text: "Thinking...",
|
||||
Thought: true,
|
||||
ThoughtSignature: dummyThoughtSignature,
|
||||
}}, parts...)
|
||||
}
|
||||
}
|
||||
@@ -170,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
||||
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
// isValidThoughtSignature 验证 thought signature 是否有效
|
||||
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
|
||||
func isValidThoughtSignature(signature string) bool {
|
||||
// 空字符串无效
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
|
||||
// 参考 Claude API 文档和实际观察到的有效 signature
|
||||
if len(signature) < 40 {
|
||||
log.Printf("[Debug] Signature too short: len=%d", len(signature))
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否是有效的 base64 字符
|
||||
// base64 字符集: A-Z, a-z, 0-9, +, /, =
|
||||
for i, c := range signature {
|
||||
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
|
||||
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
|
||||
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// buildParts 构建消息的 parts
|
||||
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
||||
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
|
||||
@@ -198,15 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
}
|
||||
|
||||
case "thinking":
|
||||
part := GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
if allowDummyThought {
|
||||
// Gemini 模型可以使用 dummy signature
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
ThoughtSignature: dummyThoughtSignature,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
part.ThoughtSignature = block.Signature
|
||||
|
||||
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
|
||||
signature := strings.TrimSpace(block.Signature)
|
||||
if signature == "" || signature == dummyThoughtSignature {
|
||||
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
|
||||
continue
|
||||
}
|
||||
parts = append(parts, part)
|
||||
if !isValidThoughtSignature(signature) {
|
||||
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
ThoughtSignature: signature,
|
||||
})
|
||||
|
||||
case "image":
|
||||
if block.Source != nil && block.Source.Type == "base64" {
|
||||
@@ -231,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
ID: block.ID,
|
||||
},
|
||||
}
|
||||
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
|
||||
if block.Signature != "" {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if allowDummyThought {
|
||||
// 只有 Gemini 模型使用 dummy signature
|
||||
// Claude 模型不设置 signature(避免验证问题)
|
||||
if allowDummyThought {
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
parts = append(parts, part)
|
||||
@@ -378,13 +433,53 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for _, tool := range tools {
|
||||
for i, tool := range tools {
|
||||
// 跳过无效工具名称
|
||||
if strings.TrimSpace(tool.Name) == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
continue
|
||||
}
|
||||
|
||||
var description string
|
||||
var inputSchema map[string]any
|
||||
|
||||
// 检查是否为 custom 类型工具 (MCP)
|
||||
if tool.Type == "custom" {
|
||||
if tool.Custom == nil || tool.Custom.InputSchema == nil {
|
||||
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
|
||||
continue
|
||||
}
|
||||
description = tool.Custom.Description
|
||||
inputSchema = tool.Custom.InputSchema
|
||||
|
||||
// 调试日志:记录 custom 工具的 schema
|
||||
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
|
||||
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
|
||||
}
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
description = tool.Description
|
||||
inputSchema = tool.InputSchema
|
||||
}
|
||||
|
||||
// 清理 JSON Schema
|
||||
params := cleanJSONSchema(tool.InputSchema)
|
||||
params := cleanJSONSchema(inputSchema)
|
||||
// 为 nil schema 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "OBJECT",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
// 调试日志:记录清理后的 schema
|
||||
if paramsJSON, err := json.Marshal(params); err == nil {
|
||||
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
|
||||
}
|
||||
|
||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Description: description,
|
||||
Parameters: params,
|
||||
})
|
||||
}
|
||||
@@ -443,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
|
||||
}
|
||||
|
||||
// excludedSchemaKeys 不支持的 schema 字段
|
||||
// 基于 Claude API (Vertex AI) 的实际支持情况
|
||||
// 支持: type, description, enum, properties, required, additionalProperties, items
|
||||
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
|
||||
var excludedSchemaKeys = map[string]bool{
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
"additionalProperties": true,
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
"uniqueItems": true,
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"pattern": true,
|
||||
"format": true,
|
||||
"default": true,
|
||||
"strict": true,
|
||||
"const": true,
|
||||
"examples": true,
|
||||
"deprecated": true,
|
||||
"readOnly": true,
|
||||
"writeOnly": true,
|
||||
"contentMediaType": true,
|
||||
"contentEncoding": true,
|
||||
// 元 schema 字段
|
||||
"$schema": true,
|
||||
"$id": true,
|
||||
"$ref": true,
|
||||
|
||||
// 字符串验证(Gemini 不支持)
|
||||
"minLength": true,
|
||||
"maxLength": true,
|
||||
"pattern": true,
|
||||
|
||||
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"minimum": true,
|
||||
"maximum": true,
|
||||
"exclusiveMinimum": true,
|
||||
"exclusiveMaximum": true,
|
||||
"multipleOf": true,
|
||||
|
||||
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
|
||||
"uniqueItems": true,
|
||||
"minItems": true,
|
||||
"maxItems": true,
|
||||
|
||||
// 组合 schema(Gemini 不支持)
|
||||
"oneOf": true,
|
||||
"anyOf": true,
|
||||
"allOf": true,
|
||||
"not": true,
|
||||
"if": true,
|
||||
"then": true,
|
||||
"else": true,
|
||||
"$defs": true,
|
||||
"definitions": true,
|
||||
|
||||
// 对象验证(仅保留 properties/required/additionalProperties)
|
||||
"minProperties": true,
|
||||
"maxProperties": true,
|
||||
"patternProperties": true,
|
||||
"propertyNames": true,
|
||||
"dependencies": true,
|
||||
"dependentSchemas": true,
|
||||
"dependentRequired": true,
|
||||
|
||||
// 其他不支持的字段
|
||||
"default": true,
|
||||
"const": true,
|
||||
"examples": true,
|
||||
"deprecated": true,
|
||||
"readOnly": true,
|
||||
"writeOnly": true,
|
||||
"contentMediaType": true,
|
||||
"contentEncoding": true,
|
||||
|
||||
// Claude 特有字段
|
||||
"strict": true,
|
||||
}
|
||||
|
||||
// cleanSchemaValue 递归清理 schema 值
|
||||
@@ -487,6 +615,31 @@ func cleanSchemaValue(value any) any {
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
|
||||
if k == "format" {
|
||||
if formatStr, ok := val.(string); ok {
|
||||
// Gemini 只支持 date-time, date, time
|
||||
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
|
||||
result[k] = val
|
||||
}
|
||||
// 其他 format 值直接跳过
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
|
||||
if k == "additionalProperties" {
|
||||
if boolVal, ok := val.(bool); ok {
|
||||
result[k] = boolVal
|
||||
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
|
||||
} else {
|
||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||
result[k] = false
|
||||
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 递归清理所有值
|
||||
result[k] = cleanSchemaValue(val)
|
||||
}
|
||||
|
||||
179
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
179
backend/internal/pkg/antigravity/request_transformer_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
|
||||
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
allowDummyThought bool
|
||||
expectedParts int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Claude model - skip thinking block without signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 2, // 只有两个text block
|
||||
description: "Claude模型应该跳过无signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Claude model - keep thinking block with signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: false,
|
||||
expectedParts: 3, // 三个block都保留
|
||||
description: "Claude模型应该保留有signature的thinking block",
|
||||
},
|
||||
{
|
||||
name: "Gemini model - use dummy signature",
|
||||
content: `[
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
|
||||
{"type": "text", "text": "World"}
|
||||
]`,
|
||||
allowDummyThought: true,
|
||||
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
|
||||
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
toolIDToName := make(map[string]string)
|
||||
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("buildParts() error = %v", err)
|
||||
}
|
||||
|
||||
if len(parts) != tt.expectedParts {
|
||||
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
|
||||
func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools []ClaudeTool
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tool format",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"location": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "mcp_tool",
|
||||
Custom: &CustomToolSpec{
|
||||
Description: "MCP tool description",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"param": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从Custom字段读取description和input_schema",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Name: "standard_tool",
|
||||
Description: "Standard tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "custom_tool",
|
||||
Custom: &CustomToolSpec{
|
||||
Description: "Custom tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
|
||||
description: "混合标准和custom工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil Custom field",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
// Custom 为 nil
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "Custom字段为nil的custom工具应该被跳过",
|
||||
},
|
||||
{
|
||||
name: "Invalid custom tool - nil InputSchema",
|
||||
tools: []ClaudeTool{
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
Custom: &CustomToolSpec{
|
||||
Description: "Invalid",
|
||||
// InputSchema 为 nil
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "InputSchema为nil的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildTools(tt.tools)
|
||||
|
||||
if len(result) != tt.expectedLen {
|
||||
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
|
||||
}
|
||||
|
||||
// 验证function declarations存在
|
||||
if len(result) > 0 && result[0].FunctionDeclarations != nil {
|
||||
if len(result[0].FunctionDeclarations) != len(tt.tools) {
|
||||
t.Errorf("%s: got %d function declarations, want %d",
|
||||
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// Claude Code 客户端默认请求头
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
|
||||
158
backend/internal/pkg/errors/errors.go
Normal file
158
backend/internal/pkg/errors/errors.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownCode = http.StatusInternalServerError
|
||||
UnknownReason = ""
|
||||
UnknownMessage = "internal error"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
Code int32 `json:"code"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ApplicationError is the standard error type used to control HTTP responses.
|
||||
//
|
||||
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
|
||||
type ApplicationError struct {
|
||||
Status
|
||||
cause error
|
||||
}
|
||||
|
||||
// Error is kept for backwards compatibility within this package.
|
||||
type Error = ApplicationError
|
||||
|
||||
func (e *ApplicationError) Error() string {
|
||||
if e == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
if e.cause == nil {
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
|
||||
}
|
||||
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
|
||||
}
|
||||
|
||||
// Unwrap provides compatibility for Go 1.13 error chains.
|
||||
func (e *ApplicationError) Unwrap() error { return e.cause }
|
||||
|
||||
// Is matches each error in the chain with the target value.
|
||||
func (e *ApplicationError) Is(err error) bool {
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se.Code == e.Code && se.Reason == e.Reason
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithCause attaches the underlying cause of the error.
|
||||
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
|
||||
err := Clone(e)
|
||||
err.cause = cause
|
||||
return err
|
||||
}
|
||||
|
||||
// WithMetadata deep-copies the given metadata map.
|
||||
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
|
||||
err := Clone(e)
|
||||
if md == nil {
|
||||
err.Metadata = nil
|
||||
return err
|
||||
}
|
||||
err.Metadata = make(map[string]string, len(md))
|
||||
for k, v := range md {
|
||||
err.Metadata[k] = v
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// New returns an error object for the code, message.
|
||||
func New(code int, reason, message string) *ApplicationError {
|
||||
return &ApplicationError{
|
||||
Status: Status{
|
||||
Code: int32(code),
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Newf New(code fmt.Sprintf(format, a...))
|
||||
func Newf(code int, reason, format string, a ...any) *ApplicationError {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Errorf returns an error object for the code, message and error info.
|
||||
func Errorf(code int, reason, format string, a ...any) error {
|
||||
return New(code, reason, fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// Code returns the http code for an error.
|
||||
// It supports wrapped errors.
|
||||
func Code(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
return int(FromError(err).Code)
|
||||
}
|
||||
|
||||
// Reason returns the reason for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Reason(err error) string {
|
||||
if err == nil {
|
||||
return UnknownReason
|
||||
}
|
||||
return FromError(err).Reason
|
||||
}
|
||||
|
||||
// Message returns the message for a particular error.
|
||||
// It supports wrapped errors.
|
||||
func Message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return FromError(err).Message
|
||||
}
|
||||
|
||||
// Clone deep clone error to a new error.
|
||||
func Clone(err *ApplicationError) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var metadata map[string]string
|
||||
if err.Metadata != nil {
|
||||
metadata = make(map[string]string, len(err.Metadata))
|
||||
for k, v := range err.Metadata {
|
||||
metadata[k] = v
|
||||
}
|
||||
}
|
||||
return &ApplicationError{
|
||||
cause: err.cause,
|
||||
Status: Status{
|
||||
Code: err.Code,
|
||||
Reason: err.Reason,
|
||||
Message: err.Message,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// FromError tries to convert an error to *ApplicationError.
|
||||
// It supports wrapped errors.
|
||||
func FromError(err error) *ApplicationError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if se := new(ApplicationError); errors.As(err, &se) {
|
||||
return se
|
||||
}
|
||||
|
||||
// Fall back to a generic internal error.
|
||||
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
|
||||
}
|
||||
168
backend/internal/pkg/errors/errors_test.go
Normal file
168
backend/internal/pkg/errors/errors_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
//go:build unit
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplicationError_Basics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *ApplicationError
|
||||
want Status
|
||||
wantIs bool
|
||||
target error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
name: "new",
|
||||
err: New(400, "BAD_REQUEST", "invalid input"),
|
||||
want: Status{
|
||||
Code: 400,
|
||||
Reason: "BAD_REQUEST",
|
||||
Message: "invalid input",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is_matches_code_and_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "UNAUTHORIZED", "ignored message"),
|
||||
wantIs: true,
|
||||
},
|
||||
{
|
||||
name: "is_does_not_match_reason",
|
||||
err: New(401, "UNAUTHORIZED", "nope"),
|
||||
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
|
||||
target: New(401, "DIFFERENT", "ignored message"),
|
||||
wantIs: false,
|
||||
},
|
||||
{
|
||||
name: "from_error_unwraps_wrapped_application_error",
|
||||
err: New(404, "NOT_FOUND", "missing"),
|
||||
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
|
||||
want: Status{
|
||||
Code: 404,
|
||||
Reason: "NOT_FOUND",
|
||||
Message: "missing",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err != nil {
|
||||
require.Equal(t, tt.want, tt.err.Status)
|
||||
}
|
||||
|
||||
if tt.target != nil {
|
||||
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
|
||||
}
|
||||
|
||||
if tt.wrapped != nil {
|
||||
got := FromError(tt.wrapped)
|
||||
require.Equal(t, tt.want, got.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
}{
|
||||
{name: "non_nil", md: map[string]string{"a": "1"}},
|
||||
{name: "nil", md: nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
|
||||
|
||||
if tt.md == nil {
|
||||
require.Nil(t, appErr.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
tt.md["a"] = "changed"
|
||||
require.Equal(t, "1", appErr.Metadata["a"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromError_Generic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int32
|
||||
wantReason string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
err: stderrors.New("boom"),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
{
|
||||
name: "wrapped_plain_error",
|
||||
err: fmt.Errorf("wrap: %w", io.EOF),
|
||||
wantCode: UnknownCode,
|
||||
wantReason: UnknownReason,
|
||||
wantMsg: UnknownMessage,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := FromError(tt.err)
|
||||
require.Equal(t, tt.wantCode, got.Code)
|
||||
require.Equal(t, tt.wantReason, got.Reason)
|
||||
require.Equal(t, tt.wantMsg, got.Message)
|
||||
require.Equal(t, tt.err, got.Unwrap())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantBody Status
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: Status{Code: int32(http.StatusOK)},
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: Forbidden("FORBIDDEN", "no access"),
|
||||
wantStatusCode: http.StatusForbidden,
|
||||
wantBody: Status{
|
||||
Code: int32(http.StatusForbidden),
|
||||
Reason: "FORBIDDEN",
|
||||
Message: "no access",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, body := ToHTTP(tt.err)
|
||||
require.Equal(t, tt.wantStatusCode, code)
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
21
backend/internal/pkg/errors/http.go
Normal file
21
backend/internal/pkg/errors/http.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
|
||||
//
|
||||
// The returned body matches the project's Status shape:
|
||||
// { code, reason, message, metadata }.
|
||||
func ToHTTP(err error) (statusCode int, body Status) {
|
||||
if err == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
appErr := FromError(err)
|
||||
if appErr == nil {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
}
|
||||
114
backend/internal/pkg/errors/types.go
Normal file
114
backend/internal/pkg/errors/types.go
Normal file
@@ -0,0 +1,114 @@
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
import "net/http"
|
||||
|
||||
// BadRequest new BadRequest error that is mapped to a 400 response.
|
||||
func BadRequest(reason, message string) *ApplicationError {
|
||||
return New(http.StatusBadRequest, reason, message)
|
||||
}
|
||||
|
||||
// IsBadRequest determines if err is an error which indicates a BadRequest error.
|
||||
// It supports wrapped errors.
|
||||
func IsBadRequest(err error) bool {
|
||||
return Code(err) == http.StatusBadRequest
|
||||
}
|
||||
|
||||
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
|
||||
func TooManyRequests(reason, message string) *ApplicationError {
|
||||
return New(http.StatusTooManyRequests, reason, message)
|
||||
}
|
||||
|
||||
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
|
||||
// It supports wrapped errors.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
return Code(err) == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// Unauthorized new Unauthorized error that is mapped to a 401 response.
|
||||
func Unauthorized(reason, message string) *ApplicationError {
|
||||
return New(http.StatusUnauthorized, reason, message)
|
||||
}
|
||||
|
||||
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
|
||||
// It supports wrapped errors.
|
||||
func IsUnauthorized(err error) bool {
|
||||
return Code(err) == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// Forbidden new Forbidden error that is mapped to a 403 response.
|
||||
func Forbidden(reason, message string) *ApplicationError {
|
||||
return New(http.StatusForbidden, reason, message)
|
||||
}
|
||||
|
||||
// IsForbidden determines if err is an error which indicates a Forbidden error.
|
||||
// It supports wrapped errors.
|
||||
func IsForbidden(err error) bool {
|
||||
return Code(err) == http.StatusForbidden
|
||||
}
|
||||
|
||||
// NotFound new NotFound error that is mapped to a 404 response.
|
||||
func NotFound(reason, message string) *ApplicationError {
|
||||
return New(http.StatusNotFound, reason, message)
|
||||
}
|
||||
|
||||
// IsNotFound determines if err is an error which indicates an NotFound error.
|
||||
// It supports wrapped errors.
|
||||
func IsNotFound(err error) bool {
|
||||
return Code(err) == http.StatusNotFound
|
||||
}
|
||||
|
||||
// Conflict new Conflict error that is mapped to a 409 response.
|
||||
func Conflict(reason, message string) *ApplicationError {
|
||||
return New(http.StatusConflict, reason, message)
|
||||
}
|
||||
|
||||
// IsConflict determines if err is an error which indicates a Conflict error.
|
||||
// It supports wrapped errors.
|
||||
func IsConflict(err error) bool {
|
||||
return Code(err) == http.StatusConflict
|
||||
}
|
||||
|
||||
// InternalServer new InternalServer error that is mapped to a 500 response.
|
||||
func InternalServer(reason, message string) *ApplicationError {
|
||||
return New(http.StatusInternalServerError, reason, message)
|
||||
}
|
||||
|
||||
// IsInternalServer determines if err is an error which indicates an Internal error.
|
||||
// It supports wrapped errors.
|
||||
func IsInternalServer(err error) bool {
|
||||
return Code(err) == http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
|
||||
func ServiceUnavailable(reason, message string) *ApplicationError {
|
||||
return New(http.StatusServiceUnavailable, reason, message)
|
||||
}
|
||||
|
||||
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
|
||||
// It supports wrapped errors.
|
||||
func IsServiceUnavailable(err error) bool {
|
||||
return Code(err) == http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
|
||||
func GatewayTimeout(reason, message string) *ApplicationError {
|
||||
return New(http.StatusGatewayTimeout, reason, message)
|
||||
}
|
||||
|
||||
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
|
||||
// It supports wrapped errors.
|
||||
func IsGatewayTimeout(err error) bool {
|
||||
return Code(err) == http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
|
||||
func ClientClosed(reason, message string) *ApplicationError {
|
||||
return New(499, reason, message)
|
||||
}
|
||||
|
||||
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
|
||||
// It supports wrapped errors.
|
||||
func IsClientClosed(err error) bool {
|
||||
return Code(err) == 499
|
||||
}
|
||||
157
backend/internal/pkg/geminicli/drive_client.go
Normal file
157
backend/internal/pkg/geminicli/drive_client.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
// DriveStorageInfo represents Google Drive storage quota information
|
||||
type DriveStorageInfo struct {
|
||||
Limit int64 `json:"limit"` // Storage limit in bytes
|
||||
Usage int64 `json:"usage"` // Current usage in bytes
|
||||
}
|
||||
|
||||
// DriveClient interface for Google Drive API operations
|
||||
type DriveClient interface {
|
||||
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
|
||||
}
|
||||
|
||||
type driveClient struct{}
|
||||
|
||||
// NewDriveClient creates a new Drive API client
|
||||
func NewDriveClient() DriveClient {
|
||||
return &driveClient{}
|
||||
}
|
||||
|
||||
// GetStorageQuota fetches storage quota from Google Drive API
|
||||
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
|
||||
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
// Get HTTP client with proxy support
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
sleepWithContext := func(d time.Duration) error {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
|
||||
var resp *http.Response
|
||||
maxRetries := 3
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
// Network error retry
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Success
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
break
|
||||
}
|
||||
|
||||
// Retry 429, 500, 502, 503 with exponential backoff + jitter
|
||||
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusInternalServerError ||
|
||||
resp.StatusCode == http.StatusBadGateway ||
|
||||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
|
||||
if err := func() error {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||
return sleepWithContext(backoff + jitter)
|
||||
}(); err != nil {
|
||||
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("request failed: no response received")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = resp.Body.Close()
|
||||
statusText := http.StatusText(resp.StatusCode)
|
||||
if statusText == "" {
|
||||
statusText = resp.Status
|
||||
}
|
||||
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
|
||||
// 只返回通用错误
|
||||
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Parse response
|
||||
var result struct {
|
||||
StorageQuota struct {
|
||||
Limit string `json:"limit"` // Can be string or number
|
||||
Usage string `json:"usage"`
|
||||
} `json:"storageQuota"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
// Parse limit and usage (handle both string and number formats)
|
||||
var limit, usage int64
|
||||
if result.StorageQuota.Limit != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
|
||||
limit = val
|
||||
}
|
||||
}
|
||||
if result.StorageQuota.Usage != "" {
|
||||
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
|
||||
usage = val
|
||||
}
|
||||
}
|
||||
|
||||
return &DriveStorageInfo{
|
||||
Limit: limit,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
18
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
18
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package geminicli
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDriveStorageInfo(t *testing.T) {
|
||||
// 测试 DriveStorageInfo 结构体
|
||||
info := &DriveStorageInfo{
|
||||
Limit: 100 * 1024 * 1024 * 1024, // 100GB
|
||||
Usage: 50 * 1024 * 1024 * 1024, // 50GB
|
||||
}
|
||||
|
||||
if info.Limit != 100*1024*1024*1024 {
|
||||
t.Errorf("Expected limit 100GB, got %d", info.Limit)
|
||||
}
|
||||
if info.Usage != 50*1024*1024*1024 {
|
||||
t.Errorf("Expected usage 50GB, got %d", info.Usage)
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: infraerrors.NotFound("NOT_FOUND", "not found"),
|
||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: infraerrors.Conflict("CONFLICT", "conflict"),
|
||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
Message: errors2.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user