feat: 品牌重命名 Sub2API -> TianShuAPI
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled

- 前端: 所有界面显示、i18n 文本、组件中的品牌名称
- 后端: 服务层、设置默认值、邮件模板、安装向导
- 数据库: 迁移脚本注释
- 保持功能完全一致,仅更改品牌名称

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
huangzhenpc
2026-01-04 17:50:29 +08:00
parent e27c1acf79
commit d274c8cb14
417 changed files with 112280 additions and 112280 deletions

View File

@@ -1,228 +1,228 @@
package antigravity
import "encoding/json"
// Claude 请求/响应类型定义
// ClaudeRequest Claude Messages API 请求
type ClaudeRequest struct {
Model string `json:"model"`
Messages []ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
Tools []ClaudeTool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
}
// ClaudeMessage Claude 消息
type ClaudeMessage struct {
Role string `json:"role"` // user, assistant
Content json.RawMessage `json:"content"`
}
// ThinkingConfig Thinking 配置
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
}
// ClaudeMetadata 请求元数据
type ClaudeMetadata struct {
UserID string `json:"user_id,omitempty"`
}
// ClaudeTool Claude 工具定义
// 支持两种格式:
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
type ClaudeTool struct {
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
Name string `json:"name"`
Description string `json:"description,omitempty"` // 标准格式使用
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
}
// CustomToolSpec MCP custom 工具规格
type CustomToolSpec struct {
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`
Text string `json:"text"`
}
// ContentBlock Claude 消息内容块(解析后)
type ContentBlock struct {
Type string `json:"type"`
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
IsError bool `json:"is_error,omitempty"`
// image
Source *ImageSource `json:"source,omitempty"`
}
// ImageSource Claude 图片来源
type ImageSource struct {
Type string `json:"type"` // "base64"
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
Data string `json:"data"`
}
// ClaudeResponse Claude Messages API 响应
type ClaudeResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ClaudeContentItem `json:"content"`
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
Usage ClaudeUsage `json:"usage"`
}
// ClaudeContentItem Claude 响应内容项
type ClaudeContentItem struct {
Type string `json:"type"` // text, thinking, tool_use
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
}
// ClaudeUsage Claude 用量统计
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
type ClaudeError struct {
Type string `json:"type"` // "error"
Error ErrorDetail `json:"error"`
}
// ErrorDetail 错误详情
type ErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`
}
// modelDef Antigravity 模型定义(内部使用)
type modelDef struct {
ID string
DisplayName string
CreatedAt string // 仅 Claude API 格式使用
}
// Antigravity 支持的 Claude 模型
var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
// ========== Claude API 格式 (/v1/models) ==========
// ClaudeModel Claude API 模型格式
type ClaudeModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels 返回 Claude API 格式的模型列表Claude + Gemini
func DefaultModels() []ClaudeModel {
all := append(claudeModels, geminiModels...)
result := make([]ClaudeModel, len(all))
for i, m := range all {
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
}
return result
}
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
// GeminiModel Gemini v1beta 模型格式
type GeminiModel struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
// GeminiModelsListResponse Gemini v1beta 模型列表响应
type GeminiModelsListResponse struct {
Models []GeminiModel `json:"models"`
}
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
func DefaultGeminiModels() []GeminiModel {
result := make([]GeminiModel, len(geminiModels))
for i, m := range geminiModels {
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
}
return result
}
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
func FallbackGeminiModelsList() GeminiModelsListResponse {
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
}
// FallbackGeminiModel 返回单个模型信息v1beta 格式)
func FallbackGeminiModel(model string) GeminiModel {
if model == "" {
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
}
name := model
if len(model) < 7 || model[:7] != "models/" {
name = "models/" + model
}
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
}
package antigravity
import "encoding/json"
// Claude 请求/响应类型定义
// ClaudeRequest Claude Messages API 请求
type ClaudeRequest struct {
Model string `json:"model"`
Messages []ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
System json.RawMessage `json:"system,omitempty"` // string 或 []SystemBlock
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
Tools []ClaudeTool `json:"tools,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Metadata *ClaudeMetadata `json:"metadata,omitempty"`
}
// ClaudeMessage Claude 消息
type ClaudeMessage struct {
Role string `json:"role"` // user, assistant
Content json.RawMessage `json:"content"`
}
// ThinkingConfig Thinking 配置
type ThinkingConfig struct {
Type string `json:"type"` // "enabled" or "disabled"
BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget
}
// ClaudeMetadata 请求元数据
type ClaudeMetadata struct {
UserID string `json:"user_id,omitempty"`
}
// ClaudeTool Claude 工具定义
// 支持两种格式:
// 1. 标准格式: { "name": "...", "description": "...", "input_schema": {...} }
// 2. Custom 格式 (MCP): { "type": "custom", "name": "...", "custom": { "description": "...", "input_schema": {...} } }
type ClaudeTool struct {
Type string `json:"type,omitempty"` // "custom" 或空(标准格式)
Name string `json:"name"`
Description string `json:"description,omitempty"` // 标准格式使用
InputSchema map[string]any `json:"input_schema,omitempty"` // 标准格式使用
Custom *CustomToolSpec `json:"custom,omitempty"` // custom 格式使用
}
// CustomToolSpec MCP custom 工具规格
type CustomToolSpec struct {
Description string `json:"description,omitempty"`
InputSchema map[string]any `json:"input_schema"`
}
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素
type SystemBlock struct {
Type string `json:"type"`
Text string `json:"text"`
}
// ContentBlock Claude 消息内容块(解析后)
type ContentBlock struct {
Type string `json:"type"`
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
// tool_result
ToolUseID string `json:"tool_use_id,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
IsError bool `json:"is_error,omitempty"`
// image
Source *ImageSource `json:"source,omitempty"`
}
// ImageSource Claude 图片来源
type ImageSource struct {
Type string `json:"type"` // "base64"
MediaType string `json:"media_type"` // "image/png", "image/jpeg" 等
Data string `json:"data"`
}
// ClaudeResponse Claude Messages API 响应
type ClaudeResponse struct {
ID string `json:"id"`
Type string `json:"type"` // "message"
Role string `json:"role"` // "assistant"
Model string `json:"model"`
Content []ClaudeContentItem `json:"content"`
StopReason string `json:"stop_reason,omitempty"` // end_turn, tool_use, max_tokens
StopSequence *string `json:"stop_sequence,omitempty"` // null 或具体值
Usage ClaudeUsage `json:"usage"`
}
// ClaudeContentItem Claude 响应内容项
type ClaudeContentItem struct {
Type string `json:"type"` // text, thinking, tool_use
// text
Text string `json:"text,omitempty"`
// thinking
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
// tool_use
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
}
// ClaudeUsage Claude 用量统计
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
type ClaudeError struct {
Type string `json:"type"` // "error"
Error ErrorDetail `json:"error"`
}
// ErrorDetail 错误详情
type ErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`
}
// modelDef Antigravity 模型定义(内部使用)
type modelDef struct {
ID string
DisplayName string
CreatedAt string // 仅 Claude API 格式使用
}
// Antigravity 支持的 Claude 模型
var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
// ========== Claude API 格式 (/v1/models) ==========
// ClaudeModel Claude API 模型格式
type ClaudeModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels 返回 Claude API 格式的模型列表Claude + Gemini
func DefaultModels() []ClaudeModel {
all := append(claudeModels, geminiModels...)
result := make([]ClaudeModel, len(all))
for i, m := range all {
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
}
return result
}
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
// GeminiModel Gemini v1beta 模型格式
type GeminiModel struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
// GeminiModelsListResponse Gemini v1beta 模型列表响应
type GeminiModelsListResponse struct {
Models []GeminiModel `json:"models"`
}
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
func DefaultGeminiModels() []GeminiModel {
result := make([]GeminiModel, len(geminiModels))
for i, m := range geminiModels {
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
}
return result
}
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
func FallbackGeminiModelsList() GeminiModelsListResponse {
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
}
// FallbackGeminiModel 返回单个模型信息v1beta 格式)
func FallbackGeminiModel(model string) GeminiModel {
if model == "" {
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
}
name := model
if len(model) < 7 || model[:7] != "models/" {
name = "models/" + model
}
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
}

View File

@@ -1,327 +1,327 @@
package antigravity
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// NewAPIRequest 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
return req, nil
}
// TokenResponse Google OAuth token 响应
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// UserInfo Google 用户信息
type UserInfo struct {
Email string `json:"email"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
}
// LoadCodeAssistRequest loadCodeAssist 请求
type LoadCodeAssistRequest struct {
Metadata struct {
IDEType string `json:"ideType"`
} `json:"metadata"`
}
// TierInfo 账户类型信息
type TierInfo struct {
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
Name string `json:"name"` // 显示名称
Description string `json:"description"` // 描述
}
// IneligibleTier 不符合条件的层级信息
type IneligibleTier struct {
Tier *TierInfo `json:"tier,omitempty"`
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
ReasonCode string `json:"reasonCode,omitempty"`
ReasonMessage string `json:"reasonMessage,omitempty"`
}
// LoadCodeAssistResponse loadCodeAssist 响应
type LoadCodeAssistResponse struct {
CloudAICompanionProject string `json:"cloudaicompanionProject"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// GetTier 获取账户类型
// 优先返回 paidTier付费订阅级别否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
if r.PaidTier != nil && r.PaidTier.ID != "" {
return r.PaidTier.ID
}
if r.CurrentTier != nil {
return r.CurrentTier.ID
}
return ""
}
// Client Antigravity API 客户端
type Client struct {
httpClient *http.Client
}
func NewClient(proxyURL string) *Client {
client := &http.Client{
Timeout: 30 * time.Second,
}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
}
}
}
return &Client{
httpClient: client,
}
}
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
params.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 交换请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取用户信息
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("用户信息请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var userInfo UserInfo
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
return nil, fmt.Errorf("用户信息解析失败: %w", err)
}
return &userInfo, nil
}
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := BaseURL + "/v1internal:loadCodeAssist"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var loadResp LoadCodeAssistResponse
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
}
// 解析原始 JSON 为 map
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
return &loadResp, rawResp, nil
}
// ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"`
ResetTime string `json:"resetTime,omitempty"`
}
// ModelInfo 模型信息
type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
}
// FetchAvailableModelsRequest fetchAvailableModels 请求
type FetchAvailableModelsRequest struct {
Project string `json:"project"`
}
// FetchAvailableModelsResponse fetchAvailableModels 响应
type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"`
}
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var modelsResp FetchAvailableModelsResponse
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
}
// 解析原始 JSON 为 map
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
return &modelsResp, rawResp, nil
}
package antigravity
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// NewAPIRequest 创建 Antigravity API 请求v1internal 端点)
func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) {
apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", UserAgent)
return req, nil
}
// TokenResponse Google OAuth token 响应
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// UserInfo Google 用户信息
type UserInfo struct {
Email string `json:"email"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
}
// LoadCodeAssistRequest loadCodeAssist 请求
type LoadCodeAssistRequest struct {
Metadata struct {
IDEType string `json:"ideType"`
} `json:"metadata"`
}
// TierInfo 账户类型信息
type TierInfo struct {
ID string `json:"id"` // free-tier, g1-pro-tier, g1-ultra-tier
Name string `json:"name"` // 显示名称
Description string `json:"description"` // 描述
}
// IneligibleTier 不符合条件的层级信息
type IneligibleTier struct {
Tier *TierInfo `json:"tier,omitempty"`
// ReasonCode 不符合条件的原因代码,如 INELIGIBLE_ACCOUNT
ReasonCode string `json:"reasonCode,omitempty"`
ReasonMessage string `json:"reasonMessage,omitempty"`
}
// LoadCodeAssistResponse loadCodeAssist 响应
type LoadCodeAssistResponse struct {
CloudAICompanionProject string `json:"cloudaicompanionProject"`
CurrentTier *TierInfo `json:"currentTier,omitempty"`
PaidTier *TierInfo `json:"paidTier,omitempty"`
IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"`
}
// GetTier 获取账户类型
// 优先返回 paidTier付费订阅级别否则返回 currentTier
func (r *LoadCodeAssistResponse) GetTier() string {
if r.PaidTier != nil && r.PaidTier.ID != "" {
return r.PaidTier.ID
}
if r.CurrentTier != nil {
return r.CurrentTier.ID
}
return ""
}
// Client Antigravity API 客户端
type Client struct {
httpClient *http.Client
}
func NewClient(proxyURL string) *Client {
client := &http.Client{
Timeout: 30 * time.Second,
}
if strings.TrimSpace(proxyURL) != "" {
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURLParsed),
}
}
}
return &Client{
httpClient: client,
}
}
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
params.Set("code_verifier", codeVerifier)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 交换请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("token 解析失败: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取用户信息
func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("用户信息请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("获取用户信息失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
}
var userInfo UserInfo
if err := json.Unmarshal(bodyBytes, &userInfo); err != nil {
return nil, fmt.Errorf("用户信息解析失败: %w", err)
}
return &userInfo, nil
}
// LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON
func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) {
reqBody := LoadCodeAssistRequest{}
reqBody.Metadata.IDEType = "ANTIGRAVITY"
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := BaseURL + "/v1internal:loadCodeAssist"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var loadResp LoadCodeAssistResponse
if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil {
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
}
// 解析原始 JSON 为 map
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
return &loadResp, rawResp, nil
}
// ModelQuotaInfo 模型配额信息
type ModelQuotaInfo struct {
RemainingFraction float64 `json:"remainingFraction"`
ResetTime string `json:"resetTime,omitempty"`
}
// ModelInfo 模型信息
type ModelInfo struct {
QuotaInfo *ModelQuotaInfo `json:"quotaInfo,omitempty"`
}
// FetchAvailableModelsRequest fetchAvailableModels 请求
type FetchAvailableModelsRequest struct {
Project string `json:"project"`
}
// FetchAvailableModelsResponse fetchAvailableModels 响应
type FetchAvailableModelsResponse struct {
Models map[string]ModelInfo `json:"models"`
}
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
}
apiURL := BaseURL + "/v1internal:fetchAvailableModels"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes)))
if err != nil {
return nil, nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", UserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes))
}
var modelsResp FetchAvailableModelsResponse
if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil {
return nil, nil, fmt.Errorf("响应解析失败: %w", err)
}
// 解析原始 JSON 为 map
var rawResp map[string]any
_ = json.Unmarshal(respBodyBytes, &rawResp)
return &modelsResp, rawResp, nil
}

View File

@@ -1,168 +1,168 @@
package antigravity
// Gemini v1internal 请求/响应类型定义
// V1InternalRequest v1internal 请求包装
type V1InternalRequest struct {
Project string `json:"project"`
RequestID string `json:"requestId"`
UserAgent string `json:"userAgent"`
RequestType string `json:"requestType,omitempty"`
Model string `json:"model"`
Request GeminiRequest `json:"request"`
}
// GeminiRequest Gemini 请求内容
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
SessionID string `json:"sessionId,omitempty"`
}
// GeminiContent Gemini 内容
type GeminiContent struct {
Role string `json:"role"` // user, model
Parts []GeminiPart `json:"parts"`
}
// GeminiPart Gemini 内容部分
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
// GeminiInlineData Gemini 内联数据(图片等)
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// GeminiFunctionCall Gemini 函数调用
type GeminiFunctionCall struct {
Name string `json:"name"`
Args any `json:"args,omitempty"`
ID string `json:"id,omitempty"`
}
// GeminiFunctionResponse Gemini 函数响应
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response map[string]any `json:"response"`
ID string `json:"id,omitempty"`
}
// GeminiGenerationConfig Gemini 生成配置
type GeminiGenerationConfig struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// GeminiThinkingConfig Gemini thinking 配置
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts"`
ThinkingBudget int `json:"thinkingBudget,omitempty"`
}
// GeminiToolDeclaration Gemini 工具声明
type GeminiToolDeclaration struct {
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
}
// GeminiFunctionDecl Gemini 函数声明
type GeminiFunctionDecl struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
}
// GeminiGoogleSearch Gemini Google 搜索工具
type GeminiGoogleSearch struct {
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
}
// GeminiEnhancedContent 增强内容配置
type GeminiEnhancedContent struct {
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
}
// GeminiImageSearch 图片搜索配置
type GeminiImageSearch struct {
MaxResultCount int `json:"maxResultCount,omitempty"`
}
// GeminiToolConfig Gemini 工具配置
type GeminiToolConfig struct {
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
}
// GeminiFunctionCallingConfig 函数调用配置
type GeminiFunctionCallingConfig struct {
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
}
// GeminiSafetySetting Gemini 安全设置
type GeminiSafetySetting struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
// V1InternalResponse v1internal 响应包装
type V1InternalResponse struct {
Response GeminiResponse `json:"response"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiResponse Gemini 响应
type GeminiResponse struct {
Candidates []GeminiCandidate `json:"candidates,omitempty"`
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
}
// DefaultStopSequences 默认停止序列
var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}
package antigravity
// Gemini v1internal 请求/响应类型定义
// V1InternalRequest v1internal 请求包装
type V1InternalRequest struct {
Project string `json:"project"`
RequestID string `json:"requestId"`
UserAgent string `json:"userAgent"`
RequestType string `json:"requestType,omitempty"`
Model string `json:"model"`
Request GeminiRequest `json:"request"`
}
// GeminiRequest Gemini 请求内容
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
SessionID string `json:"sessionId,omitempty"`
}
// GeminiContent Gemini 内容
type GeminiContent struct {
Role string `json:"role"` // user, model
Parts []GeminiPart `json:"parts"`
}
// GeminiPart Gemini 内容部分
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
// GeminiInlineData Gemini 内联数据(图片等)
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// GeminiFunctionCall Gemini 函数调用
type GeminiFunctionCall struct {
Name string `json:"name"`
Args any `json:"args,omitempty"`
ID string `json:"id,omitempty"`
}
// GeminiFunctionResponse Gemini 函数响应
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response map[string]any `json:"response"`
ID string `json:"id,omitempty"`
}
// GeminiGenerationConfig Gemini 生成配置
type GeminiGenerationConfig struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// GeminiThinkingConfig Gemini thinking 配置
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts"`
ThinkingBudget int `json:"thinkingBudget,omitempty"`
}
// GeminiToolDeclaration Gemini 工具声明
type GeminiToolDeclaration struct {
FunctionDeclarations []GeminiFunctionDecl `json:"functionDeclarations,omitempty"`
GoogleSearch *GeminiGoogleSearch `json:"googleSearch,omitempty"`
}
// GeminiFunctionDecl Gemini 函数声明
type GeminiFunctionDecl struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
}
// GeminiGoogleSearch Gemini Google 搜索工具
type GeminiGoogleSearch struct {
EnhancedContent *GeminiEnhancedContent `json:"enhancedContent,omitempty"`
}
// GeminiEnhancedContent 增强内容配置
type GeminiEnhancedContent struct {
ImageSearch *GeminiImageSearch `json:"imageSearch,omitempty"`
}
// GeminiImageSearch 图片搜索配置
type GeminiImageSearch struct {
MaxResultCount int `json:"maxResultCount,omitempty"`
}
// GeminiToolConfig Gemini 工具配置
type GeminiToolConfig struct {
FunctionCallingConfig *GeminiFunctionCallingConfig `json:"functionCallingConfig,omitempty"`
}
// GeminiFunctionCallingConfig 函数调用配置
type GeminiFunctionCallingConfig struct {
Mode string `json:"mode,omitempty"` // VALIDATED, AUTO, NONE
}
// GeminiSafetySetting Gemini 安全设置
type GeminiSafetySetting struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
// V1InternalResponse v1internal 响应包装
type V1InternalResponse struct {
Response GeminiResponse `json:"response"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiResponse Gemini 响应
type GeminiResponse struct {
Candidates []GeminiCandidate `json:"candidates,omitempty"`
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
ResponseID string `json:"responseId,omitempty"`
ModelVersion string `json:"modelVersion,omitempty"`
}
// GeminiCandidate Gemini 候选响应
type GeminiCandidate struct {
Content *GeminiContent `json:"content,omitempty"`
FinishReason string `json:"finishReason,omitempty"`
Index int `json:"index,omitempty"`
}
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
var DefaultSafetySettings = []GeminiSafetySetting{
{Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: "OFF"},
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: "OFF"},
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: "OFF"},
}
// DefaultStopSequences 默认停止序列
var DefaultStopSequences = []string{
"<|user|>",
"<|endoftext|>",
"<|end_of_turn|>",
"[DONE]",
"\n\nHuman:",
}

View File

@@ -1,200 +1,200 @@
package antigravity
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
const (
// Google OAuth 端点
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
// 固定的 redirect_uri用户需手动复制 code
RedirectURI = "http://localhost:8085/callback"
// OAuth scopes
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
"https://www.googleapis.com/auth/userinfo.email " +
"https://www.googleapis.com/auth/userinfo.profile " +
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
BaseURL = "https://cloudcode-pa.googleapis.com"
// User-Agent
UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
)
// OAuthSession 保存 OAuth 授权流程的临时状态
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore OAuth session 存储
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
func BuildAuthorizationURL(state, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("response_type", "code")
params.Set("scope", Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// GenerateMockProjectID 生成随机 project_id当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符a-z0-9
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}
package antigravity
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
const (
// Google OAuth 端点
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
// 固定的 redirect_uri用户需手动复制 code
RedirectURI = "http://localhost:8085/callback"
// OAuth scopes
Scopes = "https://www.googleapis.com/auth/cloud-platform " +
"https://www.googleapis.com/auth/userinfo.email " +
"https://www.googleapis.com/auth/userinfo.profile " +
"https://www.googleapis.com/auth/cclog " +
"https://www.googleapis.com/auth/experimentsandconfigs"
// API 端点
BaseURL = "https://cloudcode-pa.googleapis.com"
// User-Agent
UserAgent = "antigravity/1.11.9 windows/amd64"
// Session 过期时间
SessionTTL = 30 * time.Minute
)
// OAuthSession 保存 OAuth 授权流程的临时状态
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore OAuth session 存储
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// BuildAuthorizationURL 构建 Google OAuth 授权 URL
func BuildAuthorizationURL(state, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("response_type", "code")
params.Set("scope", Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// GenerateMockProjectID 生成随机 project_id当 API 不返回时使用)
// 格式:{形容词}-{名词}-{5位随机字符}
func GenerateMockProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randBytes, _ := GenerateRandomBytes(7)
adj := adjectives[int(randBytes[0])%len(adjectives)]
noun := nouns[int(randBytes[1])%len(nouns)]
// 生成 5 位随机字符a-z0-9
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
suffix := make([]byte, 5)
for i := 0; i < 5; i++ {
suffix[i] = charset[int(randBytes[i+2])%len(charset)]
}
return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,179 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -1,273 +1,273 @@
package antigravity
import (
"encoding/json"
"fmt"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
// 解包 v1internal 响应
var v1Resp V1InternalResponse
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
// 尝试直接解析为 GeminiResponse
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
processor := NewNonStreamingProcessor()
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
// 序列化
respBytes, err := json.Marshal(claudeResp)
if err != nil {
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
}
return respBytes, &claudeResp.Usage, nil
}
// NonStreamingProcessor 非流式响应处理器
type NonStreamingProcessor struct {
contentBlocks []ClaudeContentItem
textBuilder string
thinkingBuilder string
thinkingSignature string
trailingSignature string
hasToolCall bool
}
// NewNonStreamingProcessor 创建非流式响应处理器
func NewNonStreamingProcessor() *NonStreamingProcessor {
return &NonStreamingProcessor{
contentBlocks: make([]ClaudeContentItem, 0),
}
}
// Process 处理 Gemini 响应
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
// 获取 parts
var parts []GeminiPart
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
parts = geminiResp.Candidates[0].Content.Parts
}
// 处理所有 parts
for _, part := range parts {
p.processPart(&part)
}
// 刷新剩余内容
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
}
// 构建响应
return p.buildResponse(geminiResp, responseID, originalModel)
}
// processPart 处理单个 part
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
signature := part.ThoughtSignature
// 1. FunctionCall 处理
if part.FunctionCall != nil {
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.hasToolCall = true
// 生成 tool_use id
toolID := part.FunctionCall.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
}
item := ClaudeContentItem{
Type: "tool_use",
ID: toolID,
Name: part.FunctionCall.Name,
Input: part.FunctionCall.Args,
}
if signature != "" {
item.Signature = signature
}
p.contentBlocks = append(p.contentBlocks, item)
return
}
// 2. Text 处理
if part.Text != "" || part.Thought {
if part.Thought {
// Thinking part
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.flushThinking()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.thinkingBuilder += part.Text
if signature != "" {
p.thinkingSignature = signature
}
} else {
// 普通 Text
if part.Text == "" {
// 空 text 带签名 - 暂存
if signature != "" {
p.trailingSignature = signature
}
return
}
p.flushThinking()
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
}
}
}
// 3. InlineData (Image) 处理
if part.InlineData != nil && part.InlineData.Data != "" {
p.flushThinking()
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data)
p.textBuilder += markdownImg
p.flushText()
}
}
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: p.textBuilder,
})
p.textBuilder = ""
}
// flushThinking 刷新 thinking builder
func (p *NonStreamingProcessor) flushThinking() {
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: p.thinkingBuilder,
Signature: p.thinkingSignature,
})
p.thinkingBuilder = ""
p.thinkingSignature = ""
}
// buildResponse 构建最终响应
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
}
stopReason := "end_turn"
if p.hasToolCall {
stopReason = "tool_use"
} else if finishReason == "MAX_TOKENS" {
stopReason = "max_tokens"
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
usage := ClaudeUsage{}
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
// 生成响应 ID
respID := responseID
if respID == "" {
respID = geminiResp.ResponseID
}
if respID == "" {
respID = "msg_" + generateRandomID()
}
return &ClaudeResponse{
ID: respID,
Type: "message",
Role: "assistant",
Model: originalModel,
Content: p.contentBlocks,
StopReason: stopReason,
Usage: usage,
}
}
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12)
for i := range result {
result[i] = chars[i%len(chars)]
}
return string(result)
}
package antigravity
import (
"encoding/json"
"fmt"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *ClaudeUsage, error) {
// 解包 v1internal 响应
var v1Resp V1InternalResponse
if err := json.Unmarshal(geminiResp, &v1Resp); err != nil {
// 尝试直接解析为 GeminiResponse
var directResp GeminiResponse
if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
return nil, nil, fmt.Errorf("parse gemini response: %w", err)
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
processor := NewNonStreamingProcessor()
claudeResp := processor.Process(&v1Resp.Response, v1Resp.ResponseID, originalModel)
// 序列化
respBytes, err := json.Marshal(claudeResp)
if err != nil {
return nil, nil, fmt.Errorf("marshal claude response: %w", err)
}
return respBytes, &claudeResp.Usage, nil
}
// NonStreamingProcessor 非流式响应处理器
type NonStreamingProcessor struct {
contentBlocks []ClaudeContentItem
textBuilder string
thinkingBuilder string
thinkingSignature string
trailingSignature string
hasToolCall bool
}
// NewNonStreamingProcessor 创建非流式响应处理器
func NewNonStreamingProcessor() *NonStreamingProcessor {
return &NonStreamingProcessor{
contentBlocks: make([]ClaudeContentItem, 0),
}
}
// Process 处理 Gemini 响应
func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
// 获取 parts
var parts []GeminiPart
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
parts = geminiResp.Candidates[0].Content.Parts
}
// 处理所有 parts
for _, part := range parts {
p.processPart(&part)
}
// 刷新剩余内容
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
}
// 构建响应
return p.buildResponse(geminiResp, responseID, originalModel)
}
// processPart 处理单个 part
func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
signature := part.ThoughtSignature
// 1. FunctionCall 处理
if part.FunctionCall != nil {
p.flushThinking()
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.hasToolCall = true
// 生成 tool_use id
toolID := part.FunctionCall.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", part.FunctionCall.Name, generateRandomID())
}
item := ClaudeContentItem{
Type: "tool_use",
ID: toolID,
Name: part.FunctionCall.Name,
Input: part.FunctionCall.Args,
}
if signature != "" {
item.Signature = signature
}
p.contentBlocks = append(p.contentBlocks, item)
return
}
// 2. Text 处理
if part.Text != "" || part.Thought {
if part.Thought {
// Thinking part
p.flushText()
// 处理 trailingSignature
if p.trailingSignature != "" {
p.flushThinking()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.thinkingBuilder += part.Text
if signature != "" {
p.thinkingSignature = signature
}
} else {
// 普通 Text
if part.Text == "" {
// 空 text 带签名 - 暂存
if signature != "" {
p.trailingSignature = signature
}
return
}
p.flushThinking()
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: p.trailingSignature,
})
p.trailingSignature = ""
}
p.textBuilder += part.Text
// 非空 text 带签名 - 立即刷新并输出空 thinking 块
if signature != "" {
p.flushText()
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
}
}
}
// 3. InlineData (Image) 处理
if part.InlineData != nil && part.InlineData.Data != "" {
p.flushThinking()
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data)
p.textBuilder += markdownImg
p.flushText()
}
}
// flushText 刷新 text builder
func (p *NonStreamingProcessor) flushText() {
if p.textBuilder == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "text",
Text: p.textBuilder,
})
p.textBuilder = ""
}
// flushThinking 刷新 thinking builder
func (p *NonStreamingProcessor) flushThinking() {
if p.thinkingBuilder == "" && p.thinkingSignature == "" {
return
}
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: p.thinkingBuilder,
Signature: p.thinkingSignature,
})
p.thinkingBuilder = ""
p.thinkingSignature = ""
}
// buildResponse 构建最终响应
func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, responseID, originalModel string) *ClaudeResponse {
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
}
stopReason := "end_turn"
if p.hasToolCall {
stopReason = "tool_use"
} else if finishReason == "MAX_TOKENS" {
stopReason = "max_tokens"
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
usage := ClaudeUsage{}
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
// 生成响应 ID
respID := responseID
if respID == "" {
respID = geminiResp.ResponseID
}
if respID == "" {
respID = "msg_" + generateRandomID()
}
return &ClaudeResponse{
ID: respID,
Type: "message",
Role: "assistant",
Model: originalModel,
Content: p.contentBlocks,
StopReason: stopReason,
Usage: usage,
}
}
// generateRandomID 生成随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12)
for i := range result {
result[i] = chars[i%len(chars)]
}
return string(result)
}

View File

@@ -1,464 +1,464 @@
package antigravity
import (
"bytes"
"encoding/json"
"fmt"
"strings"
)
// BlockType 内容块类型
type BlockType int
const (
BlockTypeNone BlockType = iota
BlockTypeText
BlockTypeThinking
BlockTypeFunction
)
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
blockIndex int
messageStartSent bool
messageStopSent bool
usedTool bool
pendingSignature string
trailingSignature string
originalModel string
// 累计 usage
inputTokens int
outputTokens int
cacheReadTokens int
}
// NewStreamingProcessor 创建流式响应处理器
func NewStreamingProcessor(originalModel string) *StreamingProcessor {
return &StreamingProcessor{
blockType: BlockTypeNone,
originalModel: originalModel,
}
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data:") {
return nil
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
return nil
}
// 解包 v1internal 响应
var v1Resp V1InternalResponse
if err := json.Unmarshal([]byte(data), &v1Resp); err != nil {
// 尝试直接解析为 GeminiResponse
var directResp GeminiResponse
if err2 := json.Unmarshal([]byte(data), &directResp); err2 != nil {
return nil
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
geminiResp := &v1Resp.Response
var result bytes.Buffer
// 发送 message_start
if !p.messageStartSent {
_, _ = result.Write(p.emitMessageStart(&v1Resp))
}
// 更新 usage
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
p.cacheReadTokens = cached
}
// 处理 parts
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
for _, part := range geminiResp.Candidates[0].Content.Parts {
_, _ = result.Write(p.processPart(&part))
}
}
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}
}
return result.Bytes()
}
// Finish 结束处理,返回最终事件和用量
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
var result bytes.Buffer
if !p.messageStopSent {
_, _ = result.Write(p.emitFinish(""))
}
usage := &ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
return result.Bytes(), usage
}
// emitMessageStart 发送 message_start 事件
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
if p.messageStartSent {
return nil
}
usage := ClaudeUsage{}
if v1Resp.Response.UsageMetadata != nil {
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
responseID := v1Resp.ResponseID
if responseID == "" {
responseID = v1Resp.Response.ResponseID
}
if responseID == "" {
responseID = "msg_" + generateRandomID()
}
message := map[string]any{
"id": responseID,
"type": "message",
"role": "assistant",
"content": []any{},
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usage,
}
event := map[string]any{
"type": "message_start",
"message": message,
}
p.messageStartSent = true
return p.formatSSE("message_start", event)
}
// processPart 处理单个 part
func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
var result bytes.Buffer
signature := part.ThoughtSignature
// 1. FunctionCall 处理
if part.FunctionCall != nil {
// 先处理 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
return result.Bytes()
}
// 2. Text 处理
if part.Text != "" || part.Thought {
if part.Thought {
_, _ = result.Write(p.processThinking(part.Text, signature))
} else {
_, _ = result.Write(p.processText(part.Text, signature))
}
}
// 3. InlineData (Image) 处理
if part.InlineData != nil && part.InlineData.Data != "" {
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data)
_, _ = result.Write(p.processText(markdownImg, ""))
}
return result.Bytes()
}
// processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 开始或继续 thinking 块
if p.blockType != BlockTypeThinking {
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
"type": "thinking",
"thinking": "",
}))
}
if text != "" {
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
"thinking": text,
}))
}
// 暂存签名
if signature != "" {
p.pendingSignature = signature
}
return result.Bytes()
}
// processText 处理普通 text
func (p *StreamingProcessor) processText(text, signature string) []byte {
var result bytes.Buffer
// 空 text 带签名 - 暂存
if text == "" {
if signature != "" {
p.trailingSignature = signature
}
return nil
}
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 非空 text 带签名 - 特殊处理
if signature != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": text,
}))
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
return result.Bytes()
}
// 普通 text (无签名)
if p.blockType != BlockTypeText {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
}
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": text,
}))
return result.Bytes()
}
// processFunctionCall 处理 function call
func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signature string) []byte {
var result bytes.Buffer
p.usedTool = true
toolID := fc.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
}
toolUse := map[string]any{
"type": "tool_use",
"id": toolID,
"name": fc.Name,
"input": map[string]any{},
}
if signature != "" {
toolUse["signature"] = signature
}
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
// 发送 input_json_delta
if fc.Args != nil {
argsJSON, _ := json.Marshal(fc.Args)
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]any{
"partial_json": string(argsJSON),
}))
}
_, _ = result.Write(p.endBlock())
return result.Bytes()
}
// startBlock 开始新的内容块
func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[string]any) []byte {
var result bytes.Buffer
if p.blockType != BlockTypeNone {
_, _ = result.Write(p.endBlock())
}
event := map[string]any{
"type": "content_block_start",
"index": p.blockIndex,
"content_block": contentBlock,
}
_, _ = result.Write(p.formatSSE("content_block_start", event))
p.blockType = blockType
return result.Bytes()
}
// endBlock 结束当前内容块
func (p *StreamingProcessor) endBlock() []byte {
if p.blockType == BlockTypeNone {
return nil
}
var result bytes.Buffer
// Thinking 块结束时发送暂存的签名
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
"signature": p.pendingSignature,
}))
p.pendingSignature = ""
}
event := map[string]any{
"type": "content_block_stop",
"index": p.blockIndex,
}
_, _ = result.Write(p.formatSSE("content_block_stop", event))
p.blockIndex++
p.blockType = BlockTypeNone
return result.Bytes()
}
// emitDelta 发送 delta 事件
func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string]any) []byte {
delta := map[string]any{
"type": deltaType,
}
for k, v := range deltaContent {
delta[k] = v
}
event := map[string]any{
"type": "content_block_delta",
"index": p.blockIndex,
"delta": delta,
}
return p.formatSSE("content_block_delta", event)
}
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
var result bytes.Buffer
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
"type": "thinking",
"thinking": "",
}))
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
"thinking": "",
}))
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
"signature": signature,
}))
_, _ = result.Write(p.endBlock())
return result.Bytes()
}
// emitFinish 发送结束事件
func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
var result bytes.Buffer
// 关闭最后一个块
_, _ = result.Write(p.endBlock())
// 处理 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 确定 stop_reason
stopReason := "end_turn"
if p.usedTool {
stopReason = "tool_use"
} else if finishReason == "MAX_TOKENS" {
stopReason = "max_tokens"
}
usage := ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usage,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
if !p.messageStopSent {
stopEvent := map[string]any{
"type": "message_stop",
}
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
p.messageStopSent = true
}
return result.Bytes()
}
// formatSSE 格式化 SSE 事件
func (p *StreamingProcessor) formatSSE(eventType string, data any) []byte {
jsonData, err := json.Marshal(data)
if err != nil {
return nil
}
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(jsonData)))
}
package antigravity
import (
"bytes"
"encoding/json"
"fmt"
"strings"
)
// BlockType 内容块类型
type BlockType int
const (
BlockTypeNone BlockType = iota
BlockTypeText
BlockTypeThinking
BlockTypeFunction
)
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
blockIndex int
messageStartSent bool
messageStopSent bool
usedTool bool
pendingSignature string
trailingSignature string
originalModel string
// 累计 usage
inputTokens int
outputTokens int
cacheReadTokens int
}
// NewStreamingProcessor 创建流式响应处理器
func NewStreamingProcessor(originalModel string) *StreamingProcessor {
return &StreamingProcessor{
blockType: BlockTypeNone,
originalModel: originalModel,
}
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data:") {
return nil
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" || data == "[DONE]" {
return nil
}
// 解包 v1internal 响应
var v1Resp V1InternalResponse
if err := json.Unmarshal([]byte(data), &v1Resp); err != nil {
// 尝试直接解析为 GeminiResponse
var directResp GeminiResponse
if err2 := json.Unmarshal([]byte(data), &directResp); err2 != nil {
return nil
}
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
}
geminiResp := &v1Resp.Response
var result bytes.Buffer
// 发送 message_start
if !p.messageStartSent {
_, _ = result.Write(p.emitMessageStart(&v1Resp))
}
// 更新 usage
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
if geminiResp.UsageMetadata != nil {
cached := geminiResp.UsageMetadata.CachedContentTokenCount
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
p.cacheReadTokens = cached
}
// 处理 parts
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
for _, part := range geminiResp.Candidates[0].Content.Parts {
_, _ = result.Write(p.processPart(&part))
}
}
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}
}
return result.Bytes()
}
// Finish 结束处理,返回最终事件和用量
func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
var result bytes.Buffer
if !p.messageStopSent {
_, _ = result.Write(p.emitFinish(""))
}
usage := &ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
return result.Bytes(), usage
}
// emitMessageStart 发送 message_start 事件
func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte {
if p.messageStartSent {
return nil
}
usage := ClaudeUsage{}
if v1Resp.Response.UsageMetadata != nil {
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
responseID := v1Resp.ResponseID
if responseID == "" {
responseID = v1Resp.Response.ResponseID
}
if responseID == "" {
responseID = "msg_" + generateRandomID()
}
message := map[string]any{
"id": responseID,
"type": "message",
"role": "assistant",
"content": []any{},
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
"usage": usage,
}
event := map[string]any{
"type": "message_start",
"message": message,
}
p.messageStartSent = true
return p.formatSSE("message_start", event)
}
// processPart 处理单个 part
func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
var result bytes.Buffer
signature := part.ThoughtSignature
// 1. FunctionCall 处理
if part.FunctionCall != nil {
// 先处理 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
return result.Bytes()
}
// 2. Text 处理
if part.Text != "" || part.Thought {
if part.Thought {
_, _ = result.Write(p.processThinking(part.Text, signature))
} else {
_, _ = result.Write(p.processText(part.Text, signature))
}
}
// 3. InlineData (Image) 处理
if part.InlineData != nil && part.InlineData.Data != "" {
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data)
_, _ = result.Write(p.processText(markdownImg, ""))
}
return result.Bytes()
}
// processThinking 处理 thinking
func (p *StreamingProcessor) processThinking(text, signature string) []byte {
var result bytes.Buffer
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 开始或继续 thinking 块
if p.blockType != BlockTypeThinking {
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
"type": "thinking",
"thinking": "",
}))
}
if text != "" {
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
"thinking": text,
}))
}
// 暂存签名
if signature != "" {
p.pendingSignature = signature
}
return result.Bytes()
}
// processText 处理普通 text
func (p *StreamingProcessor) processText(text, signature string) []byte {
var result bytes.Buffer
// 空 text 带签名 - 暂存
if text == "" {
if signature != "" {
p.trailingSignature = signature
}
return nil
}
// 处理之前的 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 非空 text 带签名 - 特殊处理
if signature != "" {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": text,
}))
_, _ = result.Write(p.endBlock())
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
return result.Bytes()
}
// 普通 text (无签名)
if p.blockType != BlockTypeText {
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{
"type": "text",
"text": "",
}))
}
_, _ = result.Write(p.emitDelta("text_delta", map[string]any{
"text": text,
}))
return result.Bytes()
}
// processFunctionCall 处理 function call
func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signature string) []byte {
var result bytes.Buffer
p.usedTool = true
toolID := fc.ID
if toolID == "" {
toolID = fmt.Sprintf("%s-%s", fc.Name, generateRandomID())
}
toolUse := map[string]any{
"type": "tool_use",
"id": toolID,
"name": fc.Name,
"input": map[string]any{},
}
if signature != "" {
toolUse["signature"] = signature
}
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
// 发送 input_json_delta
if fc.Args != nil {
argsJSON, _ := json.Marshal(fc.Args)
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]any{
"partial_json": string(argsJSON),
}))
}
_, _ = result.Write(p.endBlock())
return result.Bytes()
}
// startBlock 开始新的内容块
func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[string]any) []byte {
var result bytes.Buffer
if p.blockType != BlockTypeNone {
_, _ = result.Write(p.endBlock())
}
event := map[string]any{
"type": "content_block_start",
"index": p.blockIndex,
"content_block": contentBlock,
}
_, _ = result.Write(p.formatSSE("content_block_start", event))
p.blockType = blockType
return result.Bytes()
}
// endBlock 结束当前内容块
func (p *StreamingProcessor) endBlock() []byte {
if p.blockType == BlockTypeNone {
return nil
}
var result bytes.Buffer
// Thinking 块结束时发送暂存的签名
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
"signature": p.pendingSignature,
}))
p.pendingSignature = ""
}
event := map[string]any{
"type": "content_block_stop",
"index": p.blockIndex,
}
_, _ = result.Write(p.formatSSE("content_block_stop", event))
p.blockIndex++
p.blockType = BlockTypeNone
return result.Bytes()
}
// emitDelta 发送 delta 事件
func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string]any) []byte {
delta := map[string]any{
"type": deltaType,
}
for k, v := range deltaContent {
delta[k] = v
}
event := map[string]any{
"type": "content_block_delta",
"index": p.blockIndex,
"delta": delta,
}
return p.formatSSE("content_block_delta", event)
}
// emitEmptyThinkingWithSignature 发送空 thinking 块承载签名
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
var result bytes.Buffer
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]any{
"type": "thinking",
"thinking": "",
}))
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]any{
"thinking": "",
}))
_, _ = result.Write(p.emitDelta("signature_delta", map[string]any{
"signature": signature,
}))
_, _ = result.Write(p.endBlock())
return result.Bytes()
}
// emitFinish 发送结束事件
func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
var result bytes.Buffer
// 关闭最后一个块
_, _ = result.Write(p.endBlock())
// 处理 trailingSignature
if p.trailingSignature != "" {
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = ""
}
// 确定 stop_reason
stopReason := "end_turn"
if p.usedTool {
stopReason = "tool_use"
} else if finishReason == "MAX_TOKENS" {
stopReason = "max_tokens"
}
usage := ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
deltaEvent := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": usage,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
if !p.messageStopSent {
stopEvent := map[string]any{
"type": "message_stop",
}
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
p.messageStopSent = true
}
return result.Bytes()
}
// formatSSE 格式化 SSE 事件
func (p *StreamingProcessor) formatSSE(eventType string, data any) []byte {
jsonData, err := json.Marshal(data)
if err != nil {
return nil
}
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(jsonData)))
}

View File

@@ -1,80 +1,80 @@
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"
package claude
// Claude Code 客户端相关常量
// Beta header 常量
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.52.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
"X-Stainless-Runtime-Version": "v22.14.0",
"X-Stainless-Retry-Count": "0",
"X-Stainless-Timeout": "60",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
// Model 表示一个 Claude 模型
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels Claude Code 客户端支持的默认模型列表
var DefaultModels = []Model{
{
ID: "claude-opus-4-5-20251101",
Type: "model",
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
DisplayName: "Claude Sonnet 4.5",
CreatedAt: "2025-09-29T00:00:00Z",
},
{
ID: "claude-haiku-4-5-20251001",
Type: "model",
DisplayName: "Claude Haiku 4.5",
CreatedAt: "2025-10-01T00:00:00Z",
},
}
// DefaultModelIDs 返回默认模型的 ID 列表
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel 测试时使用的默认模型
const DefaultTestModel = "claude-sonnet-4-5-20250929"

View File

@@ -1,10 +1,10 @@
// Package ctxkey 定义用于 context.Value 的类型安全 key
package ctxkey
// Key 定义 context key 的类型,避免使用内置 string 类型staticcheck SA1029
type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
)
// Package ctxkey 定义用于 context.Value 的类型安全 key
package ctxkey
// Key 定义 context key 的类型,避免使用内置 string 类型staticcheck SA1029
type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
)

View File

@@ -1,158 +1,158 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
UnknownMessage = "internal error"
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
}
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
UnknownMessage = "internal error"
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, UnknownMessage).WithCause(err)
}

View File

@@ -1,168 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: UnknownMessage,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -1,21 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -1,114 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -1,44 +1,44 @@
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
Description string `json:"description,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
type ModelsListResponse struct {
Models []Model `json:"models"`
}
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}
func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()}
}
func FallbackModel(model string) Model {
methods := []string{"generateContent", "streamGenerateContent"}
if model == "" {
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
}
if len(model) >= 7 && model[:7] == "models/" {
return Model{Name: model, SupportedGenerationMethods: methods}
}
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
}
package gemini
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
Description string `json:"description,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
type ModelsListResponse struct {
Models []Model `json:"models"`
}
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
}
}
func FallbackModelsList() ModelsListResponse {
return ModelsListResponse{Models: DefaultModels()}
}
func FallbackModel(model string) Model {
methods := []string{"generateContent", "streamGenerateContent"}
if model == "" {
return Model{Name: "models/unknown", SupportedGenerationMethods: methods}
}
if len(model) >= 7 && model[:7] == "models/" {
return Model{Name: model, SupportedGenerationMethods: methods}
}
return Model{Name: "models/" + model, SupportedGenerationMethods: methods}
}

View File

@@ -1,38 +1,38 @@
package geminicli
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type LoadCodeAssistMetadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform"`
PluginType string `json:"pluginType"`
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`
}
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type OnboardUserResponse struct {
Done bool `json:"done"`
Response *OnboardUserResultData `json:"response,omitempty"`
Name string `json:"name,omitempty"`
}
type OnboardUserResultData struct {
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
}
package geminicli
// LoadCodeAssistRequest matches done-hub's internal Code Assist call.
type LoadCodeAssistRequest struct {
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type LoadCodeAssistMetadata struct {
IDEType string `json:"ideType"`
Platform string `json:"platform"`
PluginType string `json:"pluginType"`
}
type LoadCodeAssistResponse struct {
CurrentTier string `json:"currentTier,omitempty"`
CloudAICompanionProject string `json:"cloudaicompanionProject,omitempty"`
AllowedTiers []AllowedTier `json:"allowedTiers,omitempty"`
}
type AllowedTier struct {
ID string `json:"id"`
IsDefault bool `json:"isDefault,omitempty"`
}
type OnboardUserRequest struct {
TierID string `json:"tierId"`
Metadata LoadCodeAssistMetadata `json:"metadata"`
}
type OnboardUserResponse struct {
Done bool `json:"done"`
Response *OnboardUserResultData `json:"response,omitempty"`
Name string `json:"name,omitempty"`
}
type OnboardUserResultData struct {
CloudAICompanionProject any `json:"cloudaicompanionProject,omitempty"`
}

View File

@@ -1,42 +1,42 @@
package geminicli
import "time"
const (
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
// Note: You still need to register this redirect URI in your Google OAuth client
// unless you use an OAuth client type that permits localhost redirect URIs.
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
// Required by Google's Code Assist API.
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
// Reference: https://ai.google.dev/gemini-api/docs/oauth
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
// Note: Google Auth platform currently documents the OAuth scope as
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
)
package geminicli
import "time"
const (
AIStudioBaseURL = "https://generativelanguage.googleapis.com"
GeminiCliBaseURL = "https://cloudcode-pa.googleapis.com"
AuthorizeURL = "https://accounts.google.com/o/oauth2/v2/auth"
TokenURL = "https://oauth2.googleapis.com/token"
// AIStudioOAuthRedirectURI is the default redirect URI used for AI Studio OAuth.
// This matches the "copy/paste callback URL" flow used by OpenAI OAuth in this project.
// Note: You still need to register this redirect URI in your Google OAuth client
// unless you use an OAuth client type that permits localhost redirect URIs.
AIStudioOAuthRedirectURI = "http://localhost:1455/auth/callback"
// DefaultScopes for Code Assist (includes cloud-platform for API access plus userinfo scopes)
// Required by Google's Code Assist API.
DefaultCodeAssistScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// DefaultScopes for AI Studio (uses generativelanguage API with OAuth)
// Reference: https://ai.google.dev/gemini-api/docs/oauth
// For regular Google accounts, supports API calls to generativelanguage.googleapis.com
// Note: Google Auth platform currently documents the OAuth scope as
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
GeminiCLIRedirectURI = "https://codeassist.google.com/authcode"
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
SessionTTL = 30 * time.Minute
// GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints.
GeminiCLIUserAgent = "GeminiCLI/0.1.5 (Windows; AMD64)"
)

View File

@@ -1,157 +1,157 @@
package geminicli
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type DriveStorageInfo struct {
Limit int64 `json:"limit"` // Storage limit in bytes
Usage int64 `json:"usage"` // Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type DriveClient interface {
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
}
type driveClient struct{}
// NewDriveClient creates a new Drive API client
func NewDriveClient() DriveClient {
return &driveClient{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Get HTTP client with proxy support
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response
maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req)
if err != nil {
// Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
}
// Success
if resp.StatusCode == http.StatusOK {
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
if err := func() error {
defer func() { _ = resp.Body.Close() }()
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
return sleepWithContext(backoff + jitter)
}(); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
statusText := http.StatusText(resp.StatusCode)
if statusText == "" {
statusText = resp.Status
}
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
// Parse response
var result struct {
StorageQuota struct {
Limit string `json:"limit"` // Can be string or number
Usage string `json:"usage"`
} `json:"storageQuota"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Parse limit and usage (handle both string and number formats)
var limit, usage int64
if result.StorageQuota.Limit != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
}
if result.StorageQuota.Usage != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
}
return &DriveStorageInfo{
Limit: limit,
Usage: usage,
}, nil
}
package geminicli
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// DriveStorageInfo represents Google Drive storage quota information
type DriveStorageInfo struct {
Limit int64 `json:"limit"` // Storage limit in bytes
Usage int64 `json:"usage"` // Current usage in bytes
}
// DriveClient interface for Google Drive API operations
type DriveClient interface {
GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error)
}
type driveClient struct{}
// NewDriveClient creates a new Drive API client
func NewDriveClient() DriveClient {
return &driveClient{}
}
// GetStorageQuota fetches storage quota from Google Drive API
func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*DriveStorageInfo, error) {
const driveAPIURL = "https://www.googleapis.com/drive/v3/about?fields=storageQuota"
req, err := http.NewRequestWithContext(ctx, "GET", driveAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
// Get HTTP client with proxy support
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 10 * time.Second,
})
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response
maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req)
if err != nil {
// Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
}
// Success
if resp.StatusCode == http.StatusOK {
break
}
// Retry 429, 500, 502, 503 with exponential backoff + jitter
if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
if err := func() error {
defer func() { _ = resp.Body.Close() }()
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
return sleepWithContext(backoff + jitter)
}(); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
statusText := http.StatusText(resp.StatusCode)
if statusText == "" {
statusText = resp.Status
}
fmt.Printf("[DriveClient] Drive API error: status=%d, msg=%s\n", resp.StatusCode, statusText)
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
}
defer func() { _ = resp.Body.Close() }()
// Parse response
var result struct {
StorageQuota struct {
Limit string `json:"limit"` // Can be string or number
Usage string `json:"usage"`
} `json:"storageQuota"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Parse limit and usage (handle both string and number formats)
var limit, usage int64
if result.StorageQuota.Limit != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
}
if result.StorageQuota.Usage != "" {
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
}
return &DriveStorageInfo{
Limit: limit,
Usage: usage,
}, nil
}

View File

@@ -1,18 +1,18 @@
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}

View File

@@ -1,21 +1,21 @@
package geminicli
// Model represents a selectable Gemini model for UI/testing purposes.
// Keep JSON fields consistent with existing frontend expectations.
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-3-pro-preview"
package geminicli
// Model represents a selectable Gemini model for UI/testing purposes.
// Keep JSON fields consistent with existing frontend expectations.
type Model struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
const DefaultTestModel = "gemini-3-pro-preview"

View File

@@ -1,243 +1,243 @@
package geminicli
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
type OAuthConfig struct {
ClientID string
ClientSecret string
Scopes string
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// EffectiveOAuthConfig returns the effective OAuth configuration.
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
//
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
//
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
// https://www.googleapis.com/auth/generative-language), which will surface as
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
effective := OAuthConfig{
ClientID: strings.TrimSpace(cfg.ClientID),
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
Scopes: strings.TrimSpace(cfg.Scopes),
}
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
if effective.Scopes != "" {
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
if effective.ClientID == "" && effective.ClientSecret == "" {
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
continue
}
filtered = append(filtered, s)
}
if len(filtered) == 0 {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = strings.Join(filtered, " ")
}
}
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
if oauthType == "ai_studio" && effective.Scopes != "" {
parts := strings.Fields(effective.Scopes)
for i := range parts {
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
}
}
effective.Scopes = strings.Join(parts, " ")
}
return effective, nil
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {
return "", err
}
redirectURI = strings.TrimSpace(redirectURI)
if redirectURI == "" {
return "", fmt.Errorf("redirect_uri is required")
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", effectiveCfg.ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", effectiveCfg.Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
if strings.TrimSpace(projectID) != "" {
params.Set("project_id", strings.TrimSpace(projectID))
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
}
package geminicli
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
type OAuthConfig struct {
ClientID string
ClientSecret string
Scopes string
}
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type"` // "code_assist" 或 "ai_studio"
CreatedAt time.Time `json:"created_at"`
}
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
go store.cleanup()
return store
}
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
func (s *SessionStore) Stop() {
select {
case <-s.stopCh:
return
default:
close(s.stopCh)
}
}
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier returns an RFC 7636 compatible code verifier (43+ chars).
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
func base64URLEncode(data []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(data), "=")
}
// EffectiveOAuthConfig returns the effective OAuth configuration.
// oauthType: "code_assist" or "ai_studio" (defaults to "code_assist" if empty).
//
// If ClientID/ClientSecret is not provided, this falls back to the built-in Gemini CLI OAuth client.
//
// Note: The built-in Gemini CLI OAuth client is restricted and may reject some scopes (e.g.
// https://www.googleapis.com/auth/generative-language), which will surface as
// "restricted_client" / "Unregistered scope(s)" errors during browser authorization.
func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error) {
effective := OAuthConfig{
ClientID: strings.TrimSpace(cfg.ClientID),
ClientSecret: strings.TrimSpace(cfg.ClientSecret),
Scopes: strings.TrimSpace(cfg.Scopes),
}
// Normalize scopes: allow comma-separated input but send space-delimited scopes to Google.
if effective.Scopes != "" {
effective.Scopes = strings.Join(strings.Fields(strings.ReplaceAll(effective.Scopes, ",", " ")), " ")
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
if effective.ClientID == "" && effective.ClientSecret == "" {
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
if effective.Scopes == "" {
// Use different default scopes based on OAuth type
if oauthType == "ai_studio" {
// Built-in client can't request some AI Studio scopes (notably generative-language).
if isBuiltinClient {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultAIStudioScopes
}
} else {
// Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes
}
} else if oauthType == "ai_studio" && isBuiltinClient {
// If user overrides scopes while still using the built-in client, strip restricted scopes.
parts := strings.Fields(effective.Scopes)
filtered := make([]string, 0, len(parts))
for _, s := range parts {
if strings.Contains(s, "generative-language") {
continue
}
filtered = append(filtered, s)
}
if len(filtered) == 0 {
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = strings.Join(filtered, " ")
}
}
// Backward compatibility: normalize older AI Studio scope to the currently documented one.
if oauthType == "ai_studio" && effective.Scopes != "" {
parts := strings.Fields(effective.Scopes)
for i := range parts {
if parts[i] == "https://www.googleapis.com/auth/generative-language" {
parts[i] = "https://www.googleapis.com/auth/generative-language.retriever"
}
}
effective.Scopes = strings.Join(parts, " ")
}
return effective, nil
}
func BuildAuthorizationURL(cfg OAuthConfig, state, codeChallenge, redirectURI, projectID, oauthType string) (string, error) {
effectiveCfg, err := EffectiveOAuthConfig(cfg, oauthType)
if err != nil {
return "", err
}
redirectURI = strings.TrimSpace(redirectURI)
if redirectURI == "" {
return "", fmt.Errorf("redirect_uri is required")
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", effectiveCfg.ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", effectiveCfg.Scopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
params.Set("access_type", "offline")
params.Set("prompt", "consent")
params.Set("include_granted_scopes", "true")
if strings.TrimSpace(projectID) != "" {
params.Set("project_id", strings.TrimSpace(projectID))
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()), nil
}

View File

@@ -1,46 +1,46 @@
package geminicli
import "strings"
const maxLogBodyLen = 2048
func SanitizeBodyForLogs(body string) string {
body = truncateBase64InMessage(body)
if len(body) > maxLogBodyLen {
body = body[:maxLogBodyLen] + "...[truncated]"
}
return body
}
func truncateBase64InMessage(message string) string {
const maxBase64Length = 50
result := message
offset := 0
for {
idx := strings.Index(result[offset:], ";base64,")
if idx == -1 {
break
}
actualIdx := offset + idx
start := actualIdx + len(";base64,")
end := start
for end < len(result) && isBase64Char(result[end]) {
end++
}
if end-start > maxBase64Length {
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
offset = start + maxBase64Length + len("...[truncated]")
continue
}
offset = end
}
return result
}
func isBase64Char(c byte) bool {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
}
package geminicli
import "strings"
const maxLogBodyLen = 2048
func SanitizeBodyForLogs(body string) string {
body = truncateBase64InMessage(body)
if len(body) > maxLogBodyLen {
body = body[:maxLogBodyLen] + "...[truncated]"
}
return body
}
func truncateBase64InMessage(message string) string {
const maxBase64Length = 50
result := message
offset := 0
for {
idx := strings.Index(result[offset:], ";base64,")
if idx == -1 {
break
}
actualIdx := offset + idx
start := actualIdx + len(";base64,")
end := start
for end < len(result) && isBase64Char(result[end]) {
end++
}
if end-start > maxBase64Length {
result = result[:start+maxBase64Length] + "...[truncated]" + result[end:]
offset = start + maxBase64Length + len("...[truncated]")
continue
}
offset = end
}
return result
}
func isBase64Char(c byte) bool {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '='
}

View File

@@ -1,9 +1,9 @@
package geminicli
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope,omitempty"`
}
package geminicli
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope,omitempty"`
}

View File

@@ -1,24 +1,24 @@
package googleapi
import "net/http"
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
func HTTPStatusToGoogleStatus(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusUnauthorized:
return "UNAUTHENTICATED"
case http.StatusForbidden:
return "PERMISSION_DENIED"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
default:
if status >= 500 {
return "INTERNAL"
}
return "UNKNOWN"
}
}
package googleapi
import "net/http"
// HTTPStatusToGoogleStatus maps HTTP status codes to Google-style error status strings.
func HTTPStatusToGoogleStatus(status int) string {
switch status {
case http.StatusBadRequest:
return "INVALID_ARGUMENT"
case http.StatusUnauthorized:
return "UNAUTHENTICATED"
case http.StatusForbidden:
return "PERMISSION_DENIED"
case http.StatusNotFound:
return "NOT_FOUND"
case http.StatusTooManyRequests:
return "RESOURCE_EXHAUSTED"
default:
if status >= 500 {
return "INTERNAL"
}
return "UNKNOWN"
}
}

View File

@@ -1,157 +1,157 @@
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
if client, ok := cached.(*http.Client); ok {
return client, nil
}
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
if c, ok := actual.(*http.Client); ok {
return c, nil
}
return client, nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
if client, ok := cached.(*http.Client); ok {
return client, nil
}
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
if c, ok := actual.(*http.Client); ok {
return c, nil
}
return client, nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}

View File

@@ -1,236 +1,236 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
ScopeInference = "user:inference"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
Scope string `json:"scope"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type OrgInfo struct {
UUID string `json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// Claude OAuth Constants (from CRS project)
const (
// OAuth Client ID for Claude
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
// OAuth endpoints
AuthorizeURL = "https://claude.ai/oauth/authorize"
TokenURL = "https://console.anthropic.com/v1/oauth/token"
RedirectURI = "https://console.anthropic.com/oauth/code/callback"
// Scopes
ScopeProfile = "user:profile"
ScopeInference = "user:inference"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
Scope string `json:"scope"`
ProxyURL string `json:"proxy_url,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url)
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return base64URLEncode(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, scope string) string {
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", RedirectURI)
params.Set("scope", scope)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
}
// TokenResponse represents the token response from OAuth provider
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
// Organization and Account info from OAuth response
Organization *OrgInfo `json:"organization,omitempty"`
Account *AccountInfo `json:"account,omitempty"`
}
// OrgInfo represents organization info from OAuth response
type OrgInfo struct {
UUID string `json:"uuid"`
}
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
UUID string `json:"uuid"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
}
// BuildTokenRequest creates a token exchange request
func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest {
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: RedirectURI,
CodeVerifier: codeVerifier,
State: state,
}
}
// BuildRefreshTokenRequest creates a refresh token request
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
}
}

View File

@@ -1,42 +1,42 @@
package openai
import _ "embed"
// Model represents an OpenAI model
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel default model for testing OpenAI accounts
const DefaultTestModel = "gpt-5.1-codex"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
//
//go:embed instructions.txt
var DefaultInstructions string
package openai
import _ "embed"
// Model represents an OpenAI model
type Model struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
{ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
{ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
{ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
{ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
}
// DefaultModelIDs returns the default model ID list
func DefaultModelIDs() []string {
ids := make([]string, len(DefaultModels))
for i, m := range DefaultModels {
ids[i] = m.ID
}
return ids
}
// DefaultTestModel default model for testing OpenAI accounts
const DefaultTestModel = "gpt-5.1-codex"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
//
//go:embed instructions.txt
var DefaultInstructions string

View File

@@ -1,118 +1,118 @@
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
## General
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## Codex CLI harness, sandboxing, and approvals
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No \"save/copy this file\" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no \"above/below\"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
## General
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like \"Assigns the value to the variable\", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## Codex CLI harness, sandboxing, and approvals
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to \"never\", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `sandbox_permissions` parameter with the value `\"require_escalated\"`
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a \"review\", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Frontend tasks
When doing frontend design tasks, avoid collapsing into \"AI slop\" or safe, average-looking layouts.
Aim for interfaces that feel intentional, bold, and a bit surprising.
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
- Ensure the page loads properly on both desktop and mobile
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No \"save/copy this file\" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with \"summary\", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no \"above/below\"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Optionally include line/column (1based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\\repo\\project\\main.rs:12:5

View File

@@ -1,366 +1,366 @@
package openai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
// Default redirect URI (can be customized)
DefaultRedirectURI = "http://localhost:1455/auth/callback"
// Scopes
DefaultScopes = "openid profile email offline_access"
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
RefreshScopes = "openid profile email"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
// OpenAI uses hex encoding instead of base64url
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(64)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
// Uses base64url encoding as per RFC 7636
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
}
// TokenResponse represents the token response from OpenAI OAuth
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// IDTokenClaims represents the claims from OpenAI ID Token
type IDTokenClaims struct {
// Standard claims
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Iss string `json:"iss"`
Aud []string `json:"aud"` // OpenAI returns aud as an array
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
// OpenAI specific claims (nested under https://api.openai.com/auth)
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
}
// OpenAIAuthClaims represents the OpenAI specific auth claims
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
// OrganizationClaim represents an organization in the ID Token
type OrganizationClaim struct {
ID string `json:"id"`
Role string `json:"role"`
Title string `json:"title"`
IsDefault bool `json:"is_default"`
}
// BuildTokenRequest creates a token exchange request for OpenAI
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
}
}
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
Scope: RefreshScopes,
}
}
// ToFormData converts TokenRequest to URL-encoded form data
func (r *TokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("code", r.Code)
params.Set("redirect_uri", r.RedirectURI)
params.Set("code_verifier", r.CodeVerifier)
return params.Encode()
}
// ToFormData converts RefreshTokenRequest to URL-encoded form data
func (r *RefreshTokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("refresh_token", r.RefreshToken)
params.Set("scope", r.Scope)
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode payload (second part)
payload := parts[1]
// Add padding if necessary
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try standard encoding
decoded, err = base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
}
var claims IDTokenClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
UserID string
OrganizationID string
Organizations []OrganizationClaim
}
// GetUserInfo extracts user info from ID Token claims
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
info := &UserInfo{
Email: c.Email,
}
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations
// Get default organization ID
for _, org := range c.OpenAIAuth.Organizations {
if org.IsDefault {
info.OrganizationID = org.ID
break
}
}
// If no default, use first org
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
}
}
return info
}
package openai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"strings"
"sync"
"time"
)
// OpenAI OAuth Constants (from CRS project - Codex CLI client)
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
TokenURL = "https://auth.openai.com/oauth/token"
// Default redirect URI (can be customized)
DefaultRedirectURI = "http://localhost:1455/auth/callback"
// Scopes
DefaultScopes = "openid profile email offline_access"
// RefreshScopes - scope for token refresh (without offline_access, aligned with CRS project)
RefreshScopes = "openid profile email"
// Session TTL
SessionTTL = 30 * time.Minute
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore manages OAuth sessions in memory
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopCh chan struct{}
}
// NewSessionStore creates a new session store
func NewSessionStore() *SessionStore {
store := &SessionStore{
sessions: make(map[string]*OAuthSession),
stopCh: make(chan struct{}),
}
// Start cleanup goroutine
go store.cleanup()
return store
}
// Set stores a session
func (s *SessionStore) Set(sessionID string, session *OAuthSession) {
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[sessionID] = session
}
// Get retrieves a session
func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[sessionID]
if !ok {
return nil, false
}
// Check if expired
if time.Since(session.CreatedAt) > SessionTTL {
return nil, false
}
return session, true
}
// Delete removes a session
func (s *SessionStore) Delete(sessionID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionID)
}
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
}
// cleanup removes expired sessions periodically
func (s *SessionStore) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
s.mu.Lock()
for id, session := range s.sessions {
if time.Since(session.CreatedAt) > SessionTTL {
delete(s.sessions, id)
}
}
s.mu.Unlock()
}
}
}
// GenerateRandomBytes generates cryptographically secure random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateState generates a random state string for OAuth
func GenerateState() (string, error) {
bytes, err := GenerateRandomBytes(32)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateSessionID generates a unique session ID
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeVerifier generates a PKCE code verifier (64 bytes -> hex for OpenAI)
// OpenAI uses hex encoding instead of base64url
func GenerateCodeVerifier() (string, error) {
bytes, err := GenerateRandomBytes(64)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// GenerateCodeChallenge generates a PKCE code challenge using S256 method
// Uses base64url encoding as per RFC 7636
func GenerateCodeChallenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
ClientID string `json:"client_id"`
Code string `json:"code"`
RedirectURI string `json:"redirect_uri"`
CodeVerifier string `json:"code_verifier"`
}
// TokenResponse represents the token response from OpenAI OAuth
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
// RefreshTokenRequest represents the refresh token request
type RefreshTokenRequest struct {
GrantType string `json:"grant_type"`
RefreshToken string `json:"refresh_token"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
}
// IDTokenClaims represents the claims from OpenAI ID Token
type IDTokenClaims struct {
// Standard claims
Sub string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Iss string `json:"iss"`
Aud []string `json:"aud"` // OpenAI returns aud as an array
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
// OpenAI specific claims (nested under https://api.openai.com/auth)
OpenAIAuth *OpenAIAuthClaims `json:"https://api.openai.com/auth,omitempty"`
}
// OpenAIAuthClaims represents the OpenAI specific auth claims
type OpenAIAuthClaims struct {
ChatGPTAccountID string `json:"chatgpt_account_id"`
ChatGPTUserID string `json:"chatgpt_user_id"`
UserID string `json:"user_id"`
Organizations []OrganizationClaim `json:"organizations"`
}
// OrganizationClaim represents an organization in the ID Token
type OrganizationClaim struct {
ID string `json:"id"`
Role string `json:"role"`
Title string `json:"title"`
IsDefault bool `json:"is_default"`
}
// BuildTokenRequest creates a token exchange request for OpenAI
func BuildTokenRequest(code, codeVerifier, redirectURI string) *TokenRequest {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
return &TokenRequest{
GrantType: "authorization_code",
ClientID: ClientID,
Code: code,
RedirectURI: redirectURI,
CodeVerifier: codeVerifier,
}
}
// BuildRefreshTokenRequest creates a refresh token request for OpenAI
func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest {
return &RefreshTokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: ClientID,
Scope: RefreshScopes,
}
}
// ToFormData converts TokenRequest to URL-encoded form data
func (r *TokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("code", r.Code)
params.Set("redirect_uri", r.RedirectURI)
params.Set("code_verifier", r.CodeVerifier)
return params.Encode()
}
// ToFormData converts RefreshTokenRequest to URL-encoded form data
func (r *RefreshTokenRequest) ToFormData() string {
params := url.Values{}
params.Set("grant_type", r.GrantType)
params.Set("client_id", r.ClientID)
params.Set("refresh_token", r.RefreshToken)
params.Set("scope", r.Scope)
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
// Decode payload (second part)
payload := parts[1]
// Add padding if necessary
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try standard encoding
decoded, err = base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
}
var claims IDTokenClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
return &claims, nil
}
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string
ChatGPTUserID string
UserID string
OrganizationID string
Organizations []OrganizationClaim
}
// GetUserInfo extracts user info from ID Token claims
func (c *IDTokenClaims) GetUserInfo() *UserInfo {
info := &UserInfo{
Email: c.Email,
}
if c.OpenAIAuth != nil {
info.ChatGPTAccountID = c.OpenAIAuth.ChatGPTAccountID
info.ChatGPTUserID = c.OpenAIAuth.ChatGPTUserID
info.UserID = c.OpenAIAuth.UserID
info.Organizations = c.OpenAIAuth.Organizations
// Get default organization ID
for _, org := range c.OpenAIAuth.Organizations {
if org.IsDefault {
info.OrganizationID = org.ID
break
}
}
// If no default, use first org
if info.OrganizationID == "" && len(c.OpenAIAuth.Organizations) > 0 {
info.OrganizationID = c.OpenAIAuth.Organizations[0].ID
}
}
return info
}

View File

@@ -1,18 +1,18 @@
package openai
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
"codex_vscode/",
"codex_cli_rs/",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
return true
}
}
return false
}
package openai
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
"codex_vscode/",
"codex_cli_rs/",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
return true
}
}
return false
}

View File

@@ -1,42 +1,42 @@
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -1,185 +1,185 @@
package response
import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items any `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data any) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Created 返回创建成功响应
func Created(c *gin.Context, data any) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: "",
Metadata: nil,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
}
// Unauthorized 返回401错误
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, message)
}
// Forbidden 返回403错误
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, message)
}
// NotFound 返回404错误
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, message)
}
// InternalError 返回500错误
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, message)
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
Success(c, PaginatedData{
Items: items,
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
})
}
// PaginationResult 分页结果与pagination.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,
Total: 0,
Page: 1,
PageSize: 20,
Pages: 1,
})
return
}
Success(c, PaginatedData{
Items: items,
Total: pagination.Total,
Page: pagination.Page,
PageSize: pagination.PageSize,
Pages: pagination.Pages,
})
}
// ParsePagination 解析分页参数
func ParsePagination(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p := c.Query("page"); p != "" {
if val, err := parseInt(p); err == nil && val > 0 {
page = val
}
}
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}
return page, pageSize
}
func parseInt(s string) (int, error) {
var result int
for _, c := range s {
if c < '0' || c > '9' {
return 0, nil
}
result = result*10 + int(c-'0')
}
return result, nil
}
package response
import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
type PaginatedData struct {
Items any `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Pages int `json:"pages"`
}
// Success 返回成功响应
func Success(c *gin.Context, data any) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Created 返回创建成功响应
func Created(c *gin.Context, data any) {
c.JSON(http.StatusCreated, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: "",
Metadata: nil,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
}
// Unauthorized 返回401错误
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, message)
}
// Forbidden 返回403错误
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, message)
}
// NotFound 返回404错误
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, message)
}
// InternalError 返回500错误
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, message)
}
// Paginated 返回分页数据
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
pages := int(math.Ceil(float64(total) / float64(pageSize)))
if pages < 1 {
pages = 1
}
Success(c, PaginatedData{
Items: items,
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
})
}
// PaginationResult 分页结果与pagination.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// PaginatedWithResult 使用PaginationResult返回分页数据
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
if pagination == nil {
Success(c, PaginatedData{
Items: items,
Total: 0,
Page: 1,
PageSize: 20,
Pages: 1,
})
return
}
Success(c, PaginatedData{
Items: items,
Total: pagination.Total,
Page: pagination.Page,
PageSize: pagination.PageSize,
Pages: pagination.Pages,
})
}
// ParsePagination 解析分页参数
func ParsePagination(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p := c.Query("page"); p != "" {
if val, err := parseInt(p); err == nil && val > 0 {
page = val
}
}
// 支持 page_size 和 limit 两种参数名
if ps := c.Query("page_size"); ps != "" {
if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 {
pageSize = val
}
} else if l := c.Query("limit"); l != "" {
if val, err := parseInt(l); err == nil && val > 0 && val <= 100 {
pageSize = val
}
}
return page, pageSize
}
func parseInt(s string) (int, error) {
var result int
for _, c := range s {
if c < '0' || c > '9' {
return 0, nil
}
result = result*10 + int(c-'0')
}
return result, nil
}

View File

@@ -1,171 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: errors2.UnknownMessage,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: errors2.UnknownMessage,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -1,47 +1,47 @@
package sysutil
import (
"log"
"os"
"runtime"
"time"
)
// RestartService triggers a service restart by gracefully exiting.
//
// This relies on systemd's Restart=always configuration to automatically
// restart the service after it exits. This is the industry-standard approach:
// - Simple and reliable
// - No sudo permissions needed
// - No complex process management
// - Leverages systemd's native restart capability
//
// Prerequisites:
// - Linux OS with systemd
// - Service configured with Restart=always in systemd unit file
func RestartService() error {
if runtime.GOOS != "linux" {
log.Println("Service restart via exit only works on Linux with systemd")
return nil
}
log.Println("Initiating service restart by graceful exit...")
log.Println("systemd will automatically restart the service (Restart=always)")
// Give a moment for logs to flush and response to be sent
go func() {
time.Sleep(100 * time.Millisecond)
os.Exit(0)
}()
return nil
}
// RestartServiceAsync is a fire-and-forget version of RestartService.
// It logs errors instead of returning them, suitable for goroutine usage.
func RestartServiceAsync() {
if err := RestartService(); err != nil {
log.Printf("Service restart failed: %v", err)
log.Println("Please restart the service manually: sudo systemctl restart sub2api")
}
}
package sysutil
import (
"log"
"os"
"runtime"
"time"
)
// RestartService triggers a service restart by gracefully exiting.
//
// This relies on systemd's Restart=always configuration to automatically
// restart the service after it exits. This is the industry-standard approach:
// - Simple and reliable
// - No sudo permissions needed
// - No complex process management
// - Leverages systemd's native restart capability
//
// Prerequisites:
// - Linux OS with systemd
// - Service configured with Restart=always in systemd unit file
func RestartService() error {
if runtime.GOOS != "linux" {
log.Println("Service restart via exit only works on Linux with systemd")
return nil
}
log.Println("Initiating service restart by graceful exit...")
log.Println("systemd will automatically restart the service (Restart=always)")
// Give a moment for logs to flush and response to be sent
go func() {
time.Sleep(100 * time.Millisecond)
os.Exit(0)
}()
return nil
}
// RestartServiceAsync is a fire-and-forget version of RestartService.
// It logs errors instead of returning them, suitable for goroutine usage.
func RestartServiceAsync() {
if err := RestartService(); err != nil {
log.Printf("Service restart failed: %v", err)
log.Println("Please restart the service manually: sudo systemctl restart sub2api")
}
}

View File

@@ -1,124 +1,124 @@
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package timezone
import (
"fmt"
"log"
"time"
)
var (
// location is the global timezone location
location *time.Location
// tzName stores the timezone name for logging/debugging
tzName string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func Init(tz string) error {
if tz == "" {
tz = "Asia/Shanghai" // Default timezone
}
loc, err := time.LoadLocation(tz)
if err != nil {
return fmt.Errorf("invalid timezone %q: %w", tz, err)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time.Local = loc
location = loc
tzName = tz
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
return nil
}
// getUTCOffset returns the current UTC offset for a location
func getUTCOffset(loc *time.Location) string {
_, offset := time.Now().In(loc).Zone()
hours := offset / 3600
minutes := (offset % 3600) / 60
if minutes < 0 {
minutes = -minutes
}
sign := "+"
if hours < 0 {
sign = "-"
hours = -hours
}
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func Now() time.Time {
if location == nil {
return time.Now()
}
return time.Now().In(location)
}
// Location returns the configured timezone location.
func Location() *time.Location {
if location == nil {
return time.Local
}
return location
}
// Name returns the configured timezone name.
func Name() string {
if tzName == "" {
return "Local"
}
return tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func StartOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func Today() time.Time {
return StartOfDay(Now())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func EndOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func StartOfWeek(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7 // Sunday is day 7
}
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func StartOfMonth(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
// ParseInLocation parses a time string in the configured timezone.
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}
// Package timezone provides global timezone management for the application.
// Similar to PHP's date_default_timezone_set, this package allows setting
// a global timezone that affects all time.Now() calls.
package timezone
import (
"fmt"
"log"
"time"
)
var (
// location is the global timezone location
location *time.Location
// tzName stores the timezone name for logging/debugging
tzName string
)
// Init initializes the global timezone setting.
// This should be called once at application startup.
// Example timezone values: "Asia/Shanghai", "America/New_York", "UTC"
func Init(tz string) error {
if tz == "" {
tz = "Asia/Shanghai" // Default timezone
}
loc, err := time.LoadLocation(tz)
if err != nil {
return fmt.Errorf("invalid timezone %q: %w", tz, err)
}
// Set the global Go time.Local to our timezone
// This affects time.Now() throughout the application
time.Local = loc
location = loc
tzName = tz
log.Printf("Timezone initialized: %s (UTC offset: %s)", tz, getUTCOffset(loc))
return nil
}
// getUTCOffset returns the current UTC offset for a location
func getUTCOffset(loc *time.Location) string {
_, offset := time.Now().In(loc).Zone()
hours := offset / 3600
minutes := (offset % 3600) / 60
if minutes < 0 {
minutes = -minutes
}
sign := "+"
if hours < 0 {
sign = "-"
hours = -hours
}
return fmt.Sprintf("%s%02d:%02d", sign, hours, minutes)
}
// Now returns the current time in the configured timezone.
// This is equivalent to time.Now() after Init() is called,
// but provided for explicit timezone-aware code.
func Now() time.Time {
if location == nil {
return time.Now()
}
return time.Now().In(location)
}
// Location returns the configured timezone location.
func Location() *time.Location {
if location == nil {
return time.Local
}
return location
}
// Name returns the configured timezone name.
func Name() string {
if tzName == "" {
return "Local"
}
return tzName
}
// StartOfDay returns the start of the given day (00:00:00) in the configured timezone.
func StartOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
// Today returns the start of today (00:00:00) in the configured timezone.
func Today() time.Time {
return StartOfDay(Now())
}
// EndOfDay returns the end of the given day (23:59:59.999999999) in the configured timezone.
func EndOfDay(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, loc)
}
// StartOfWeek returns the start of the week (Monday 00:00:00) for the given time.
func StartOfWeek(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
weekday := int(t.Weekday())
if weekday == 0 {
weekday = 7 // Sunday is day 7
}
return time.Date(t.Year(), t.Month(), t.Day()-weekday+1, 0, 0, 0, 0, loc)
}
// StartOfMonth returns the start of the month (1st day 00:00:00) for the given time.
func StartOfMonth(t time.Time) time.Time {
loc := Location()
t = t.In(loc)
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
// ParseInLocation parses a time string in the configured timezone.
func ParseInLocation(layout, value string) (time.Time, error) {
return time.ParseInLocation(layout, value, Location())
}

View File

@@ -1,137 +1,137 @@
package timezone
import (
"testing"
"time"
)
func TestInit(t *testing.T) {
// Test with valid timezone
err := Init("Asia/Shanghai")
if err != nil {
t.Fatalf("Init failed with valid timezone: %v", err)
}
// Verify time.Local was set
if time.Local.String() != "Asia/Shanghai" {
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
}
// Verify our location variable
if Location().String() != "Asia/Shanghai" {
t.Errorf("Location() not set correctly, got %s", Location().String())
}
// Test Name()
if Name() != "Asia/Shanghai" {
t.Errorf("Name() not set correctly, got %s", Name())
}
}
func TestInitInvalidTimezone(t *testing.T) {
err := Init("Invalid/Timezone")
if err == nil {
t.Error("Init should fail with invalid timezone")
}
}
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_, utcOffset := utcNow.Zone()
_, shanghaiOffset := shanghaiNow.Zone()
expectedDiff := 8 * 3600 // 8 hours in seconds
actualDiff := shanghaiOffset - utcOffset
if actualDiff != expectedDiff {
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
}
}
func TestToday(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today()
now := Now()
// Today should be at 00:00:00
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
t.Errorf("Today() not at start of day: %v", today)
}
// Today should be same date as now
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
}
}
func TestStartOfDay(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
startOfDay := StartOfDay(testTime)
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
if !startOfDay.Equal(expected) {
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
}
}
func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now()
// Truncate operates on UTC, not local time
truncated := now.Truncate(24 * time.Hour)
// StartOfDay operates on local time
startOfDay := StartOfDay(now)
// These will likely be different for non-UTC timezones
t.Logf("Now: %v", now)
t.Logf("Truncate(24h): %v", truncated)
t.Logf("StartOfDay: %v", startOfDay)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if startOfDay.Hour() != 0 {
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
}
}
func TestDSTAwareness(t *testing.T) {
// Test with a timezone that has DST (America/New_York)
err := Init("America/New_York")
if err != nil {
t.Skipf("America/New_York timezone not available: %v", err)
}
// Just verify it doesn't crash
_ = Today()
_ = Now()
_ = StartOfDay(Now())
}
package timezone
import (
"testing"
"time"
)
func TestInit(t *testing.T) {
// Test with valid timezone
err := Init("Asia/Shanghai")
if err != nil {
t.Fatalf("Init failed with valid timezone: %v", err)
}
// Verify time.Local was set
if time.Local.String() != "Asia/Shanghai" {
t.Errorf("time.Local not set correctly, got %s", time.Local.String())
}
// Verify our location variable
if Location().String() != "Asia/Shanghai" {
t.Errorf("Location() not set correctly, got %s", Location().String())
}
// Test Name()
if Name() != "Asia/Shanghai" {
t.Errorf("Name() not set correctly, got %s", Name())
}
}
func TestInitInvalidTimezone(t *testing.T) {
err := Init("Invalid/Timezone")
if err == nil {
t.Error("Init should fail with invalid timezone")
}
}
func TestTimeNowAffected(t *testing.T) {
// Reset to UTC first
if err := Init("UTC"); err != nil {
t.Fatalf("Init failed with UTC: %v", err)
}
utcNow := time.Now()
// Switch to Shanghai (UTC+8)
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
shanghaiNow := time.Now()
// The times should be the same instant, but different timezone representation
// Shanghai should be 8 hours ahead in display
_, utcOffset := utcNow.Zone()
_, shanghaiOffset := shanghaiNow.Zone()
expectedDiff := 8 * 3600 // 8 hours in seconds
actualDiff := shanghaiOffset - utcOffset
if actualDiff != expectedDiff {
t.Errorf("Timezone offset difference incorrect: expected %d, got %d", expectedDiff, actualDiff)
}
}
func TestToday(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
today := Today()
now := Now()
// Today should be at 00:00:00
if today.Hour() != 0 || today.Minute() != 0 || today.Second() != 0 {
t.Errorf("Today() not at start of day: %v", today)
}
// Today should be same date as now
if today.Year() != now.Year() || today.Month() != now.Month() || today.Day() != now.Day() {
t.Errorf("Today() date mismatch: today=%v, now=%v", today, now)
}
}
func TestStartOfDay(t *testing.T) {
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
// Create a time at 15:30:45
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
startOfDay := StartOfDay(testTime)
expected := time.Date(2024, 6, 15, 0, 0, 0, 0, Location())
if !startOfDay.Equal(expected) {
t.Errorf("StartOfDay incorrect: expected %v, got %v", expected, startOfDay)
}
}
func TestTruncateVsStartOfDay(t *testing.T) {
// This test demonstrates why Truncate(24*time.Hour) can be problematic
// and why StartOfDay is more reliable for timezone-aware code
if err := Init("Asia/Shanghai"); err != nil {
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
}
now := Now()
// Truncate operates on UTC, not local time
truncated := now.Truncate(24 * time.Hour)
// StartOfDay operates on local time
startOfDay := StartOfDay(now)
// These will likely be different for non-UTC timezones
t.Logf("Now: %v", now)
t.Logf("Truncate(24h): %v", truncated)
t.Logf("StartOfDay: %v", startOfDay)
// The truncated time may not be at local midnight
// StartOfDay is always at local midnight
if startOfDay.Hour() != 0 {
t.Errorf("StartOfDay should be at hour 0, got %d", startOfDay.Hour())
}
}
func TestDSTAwareness(t *testing.T) {
// Test with a timezone that has DST (America/New_York)
err := Init("America/New_York")
if err != nil {
t.Skipf("America/New_York timezone not available: %v", err)
}
// Just verify it doesn't crash
_ = Today()
_ = Now()
_ = StartOfDay(Now())
}

View File

@@ -1,8 +1,8 @@
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -1,214 +1,214 @@
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
AccountID int64
GroupID int64
Model string
Stream *bool
BillingType *int8
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}
package usagestats
import "time"
// DashboardStats 仪表盘统计
type DashboardStats struct {
// 用户统计
TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
NormalAccounts int64 `json:"normal_accounts"` // 正常账户数 (schedulable=true, status=active)
ErrorAccounts int64 `json:"error_accounts"` // 异常账户数 (status=error)
RateLimitAccounts int64 `json:"ratelimit_accounts"` // 限流账户数
OverloadAccounts int64 `json:"overload_accounts"` // 过载账户数
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint struct {
Date string `json:"date"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheTokens int64 `json:"cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ModelStat represents usage statistics for a single model
type ModelStat struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
TotalTokens int64 `json:"total_tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
UserID int64 `json:"user_id"`
Email string `json:"email"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` // 标准计费
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` // 累计标准计费
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
TodayInputTokens int64 `json:"today_input_tokens"`
TodayOutputTokens int64 `json:"today_output_tokens"`
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
TodayCost float64 `json:"today_cost"` // 今日标准计费
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
// 性能统计
AverageDurationMs float64 `json:"average_duration_ms"`
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
}
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
ApiKeyID int64
AccountID int64
GroupID int64
Model string
Stream *bool
BillingType *int8
StartTime *time.Time
EndTime *time.Time
}
// UsageStats represents usage statistics
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
ActualCost float64 `json:"actual_cost"`
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"`
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
HighestCostDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
Date string `json:"date"`
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
} `json:"highest_request_day"`
}
// AccountUsageStatsResponse represents the full usage statistics response for an account
type AccountUsageStatsResponse struct {
History []AccountUsageHistory `json:"history"`
Summary AccountUsageSummary `json:"summary"`
Models []ModelStat `json:"models"`
}