feat(antigravity): 增强网关功能和 thinking 块处理

主要改进:
- 优化 thinking blocks 过滤策略,支持 Auto 模式降级
- 将无效 thinking block 内容转为普通 text
- 保留单个空白 text block,不过滤
- 重构配额刷新机制,统一与 Claude 一致
- 支持 cachedContentTokenCount 映射到 cache_read_input_tokens
- Haiku 模型映射到 Sonnet
- 添加 /antigravity/v1/models 端点支持
- countTokens 端点直接返回空值
This commit is contained in:
ianshaw
2026-01-03 06:29:02 -08:00
parent df1ef3deb6
commit 26438f7232
15 changed files with 463 additions and 358 deletions

View File

@@ -1,5 +1,3 @@
// Package handler provides HTTP request handlers for the API gateway.
// It handles authentication, request routing, concurrency control, and billing validation.
package handler
import (
@@ -13,6 +11,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -29,7 +28,6 @@ type GatewayHandler struct {
userService *service.UserService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
opsService *service.OpsService
}
// NewGatewayHandler creates a new GatewayHandler
@@ -40,7 +38,6 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
opsService *service.OpsService,
) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
@@ -49,15 +46,14 @@ func NewGatewayHandler(
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
opsService: opsService,
}
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userAPIKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -92,7 +88,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
setOpsRequestContext(c, reqModel, reqStream)
// 验证 model 必填
if reqModel == "" {
@@ -264,7 +259,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -388,7 +383,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -405,7 +400,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Returns models based on account configurations (model_mapping whitelist)
// Falls back to default models if no whitelist is configured
func (h *GatewayHandler) Models(c *gin.Context) {
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
apiKey, _ := middleware2.GetApiKeyFromContext(c)
var groupID *int64
var platform string
@@ -451,10 +446,19 @@ func (h *GatewayHandler) Models(c *gin.Context) {
})
}
// AntigravityModels 返回 Antigravity 支持的全部模型
// GET /antigravity/models
func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": antigravity.DefaultModels(),
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -579,7 +583,6 @@ func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string)
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
recordOpsError(c, h.opsService, status, errType, message, "")
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@@ -611,7 +614,6 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
recordOpsError(c, h.opsService, status, errType, message, "")
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
@@ -625,8 +627,8 @@ func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, mess
// POST /v1/messages/count_tokens
// 特点:校验订阅/余额,但不计算并发、不记录使用量
func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 从context获取apiKey和userAPIKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return

View File

@@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -20,7 +21,7 @@ import (
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetAPIKeyFromContext(c)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -32,9 +33,9 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
return
}
// 强制 antigravity 模式:直接返回静态模型列表
// 强制 antigravity 模式:返回 antigravity 支持的模型列表
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
c.JSON(http.StatusOK, antigravity.FallbackGeminiModelsList())
return
}
@@ -66,7 +67,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetAPIKeyFromContext(c)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -84,9 +85,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
return
}
// 强制 antigravity 模式:直接返回静态模型信息
// 强制 antigravity 模式:返回 antigravity 模型信息
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
c.JSON(http.StatusOK, antigravity.FallbackGeminiModel(modelName))
return
}
@@ -119,7 +120,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetAPIKeyFromContext(c)
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
@@ -298,7 +299,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,

View File

@@ -138,3 +138,91 @@ type ErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`
}
// modelDef Antigravity 模型定义(内部使用)
type modelDef struct {
ID string
DisplayName string
CreatedAt string // 仅 Claude API 格式使用
}
// Antigravity 支持的 Claude 模型
var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
var geminiModels = []modelDef{
{ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"},
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
// ========== Claude API 格式 (/v1/models) ==========
// ClaudeModel Claude API 模型格式
type ClaudeModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
CreatedAt string `json:"created_at"`
}
// DefaultModels 返回 Claude API 格式的模型列表Claude + Gemini
func DefaultModels() []ClaudeModel {
all := append(claudeModels, geminiModels...)
result := make([]ClaudeModel, len(all))
for i, m := range all {
result[i] = ClaudeModel{ID: m.ID, Type: "model", DisplayName: m.DisplayName, CreatedAt: m.CreatedAt}
}
return result
}
// ========== Gemini v1beta 格式 (/v1beta/models) ==========
// GeminiModel Gemini v1beta 模型格式
type GeminiModel struct {
Name string `json:"name"`
DisplayName string `json:"displayName,omitempty"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
}
// GeminiModelsListResponse Gemini v1beta 模型列表响应
type GeminiModelsListResponse struct {
Models []GeminiModel `json:"models"`
}
var defaultGeminiMethods = []string{"generateContent", "streamGenerateContent"}
// DefaultGeminiModels 返回 Gemini v1beta 格式的模型列表(仅 Gemini 模型)
func DefaultGeminiModels() []GeminiModel {
result := make([]GeminiModel, len(geminiModels))
for i, m := range geminiModels {
result[i] = GeminiModel{Name: "models/" + m.ID, DisplayName: m.DisplayName, SupportedGenerationMethods: defaultGeminiMethods}
}
return result
}
// FallbackGeminiModelsList 返回 Gemini v1beta 格式的模型列表响应
func FallbackGeminiModelsList() GeminiModelsListResponse {
return GeminiModelsListResponse{Models: DefaultGeminiModels()}
}
// FallbackGeminiModel 返回单个模型信息v1beta 格式)
func FallbackGeminiModel(model string) GeminiModel {
if model == "" {
return GeminiModel{Name: "models/unknown", SupportedGenerationMethods: defaultGeminiMethods}
}
name := model
if len(model) < 7 || model[:7] != "models/" {
name = "models/" + model
}
return GeminiModel{Name: name, SupportedGenerationMethods: defaultGeminiMethods}
}

View File

@@ -1,5 +1,3 @@
// Package antigravity provides a client for interacting with Google's Antigravity API,
// handling OAuth authentication, token management, and account tier information retrieval.
package antigravity
import (
@@ -59,6 +57,29 @@ type TierInfo struct {
Description string `json:"description"` // 描述
}
// UnmarshalJSON supports both legacy string tiers and object tiers.
func (t *TierInfo) UnmarshalJSON(data []byte) error {
data = bytes.TrimSpace(data)
if len(data) == 0 || string(data) == "null" {
return nil
}
if data[0] == '"' {
var id string
if err := json.Unmarshal(data, &id); err != nil {
return err
}
t.ID = id
return nil
}
type alias TierInfo
var decoded alias
if err := json.Unmarshal(data, &decoded); err != nil {
return err
}
*t = TierInfo(decoded)
return nil
}
// IneligibleTier 不符合条件的层级信息
type IneligibleTier struct {
Tier *TierInfo `json:"tier,omitempty"`

View File

@@ -143,9 +143,10 @@ type GeminiCandidate struct {
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)

View File

@@ -150,13 +150,18 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 历史 assistant 消息不能添加没有 signature 的 dummy thinking block
if allowDummyThought && role == "model" && isThinkingEnabled && i == len(messages)-1 {
hasThoughtPart := false
for _, p := range parts {
firstPartIsThought := false
for idx, p := range parts {
if p.Thought {
hasThoughtPart = true
if idx == 0 {
firstPartIsThought = true
}
break
}
}
if !hasThoughtPart && len(parts) > 0 {
// 如果没有thinking part或者有thinking part但不在第一个位置都需要在开头添加dummy thinking block
if len(parts) > 0 && (!hasThoughtPart || !firstPartIsThought) {
// 在开头添加 dummy thinking block
parts = append([]GeminiPart{{
Text: "Thinking...",
@@ -236,6 +241,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, thought
// Claude via Vertex
// - signature 是上游返回的完整性令牌;本地不需要/无法验证,只能透传
// - 缺失/无效 signature例如来自 Gemini 的 dummy signature会导致上游 400
// - 为避免泄露 thinking 内容,缺失/无效 signature 的 thinking 直接丢弃
if signature == "" || signature == dummyThoughtSignature {
continue
}
@@ -429,7 +435,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for i, tool := range tools {
for _, tool := range tools {
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
@@ -448,10 +454,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
@@ -468,11 +470,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: description,
@@ -627,20 +624,16 @@ func cleanSchemaValue(value any) any {
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值
result[k] = cleanSchemaValue(val)
}
return result
case []any:
// 递归处理数组中的每个元素
cleaned := make([]any, 0, len(v))

View File

@@ -15,15 +15,15 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
description string
}{
{
name: "Claude model - skip thinking block without signature",
name: "Claude model - drop thinking without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
thoughtMode: thoughtSignatureModePreserve,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
expectedParts: 2, // thinking 内容被丢弃
description: "Claude模型应丢弃无signature的thinking block内容",
},
{
name: "Claude model - preserve thinking block with signature",

View File

@@ -232,10 +232,18 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
stopReason = "max_tokens"
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
usage := ClaudeUsage{}
if geminiResp.UsageMetadata != nil {
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount
cached := geminiResp.UsageMetadata.CachedContentTokenCount
prompt := geminiResp.UsageMetadata.PromptTokenCount
if cached > prompt {
cached = prompt
}
usage.InputTokens = prompt - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
// 生成响应 ID

View File

@@ -29,8 +29,9 @@ type StreamingProcessor struct {
originalModel string
// 累计 usage
inputTokens int
outputTokens int
inputTokens int
outputTokens int
cacheReadTokens int
}
// NewStreamingProcessor 创建流式响应处理器
@@ -76,9 +77,17 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
// 更新 usage
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
if geminiResp.UsageMetadata != nil {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount
cached := geminiResp.UsageMetadata.CachedContentTokenCount
prompt := geminiResp.UsageMetadata.PromptTokenCount
if cached > prompt {
cached = prompt
}
p.inputTokens = prompt - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
p.cacheReadTokens = cached
}
// 处理 parts
@@ -108,8 +117,9 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
}
usage := &ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
return result.Bytes(), usage
@@ -123,8 +133,14 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage := ClaudeUsage{}
if v1Resp.Response.UsageMetadata != nil {
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
prompt := v1Resp.Response.UsageMetadata.PromptTokenCount
if cached > prompt {
cached = prompt
}
usage.InputTokens = prompt - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
responseID := v1Resp.ResponseID
@@ -418,8 +434,9 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
}
usage := ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
deltaEvent := map[string]any{

View File

@@ -13,8 +13,8 @@ import (
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
@@ -36,7 +36,7 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -47,6 +47,9 @@ func RegisterGatewayRoutes(
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
@@ -55,14 +58,14 @@ func RegisterGatewayRoutes(
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
antigravityV1.GET("/models", h.Gateway.Models)
antigravityV1.GET("/models", h.Gateway.AntigravityModels)
antigravityV1.GET("/usage", h.Gateway.Usage)
}
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)

View File

@@ -49,11 +49,11 @@ var antigravityPrefixMapping = []struct {
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
@@ -64,6 +64,7 @@ type AntigravityGatewayService struct {
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
}
func NewAntigravityGatewayService(
@@ -72,12 +73,14 @@ func NewAntigravityGatewayService(
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
}
}
@@ -308,6 +311,7 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
}
// isSignatureRelatedError 检测是否为 signature 相关的 400 错误
// 注意:不包含 "thinking" 关键词,避免误判消息格式错误为 signature 错误
func isSignatureRelatedError(statusCode int, body []byte) bool {
if statusCode != 400 {
return false
@@ -318,7 +322,6 @@ func isSignatureRelatedError(statusCode int, body []byte) bool {
"signature",
"thought_signature",
"thoughtsignature",
"thinking",
"invalid signature",
"signature validation",
}
@@ -331,28 +334,60 @@ func isSignatureRelatedError(statusCode int, body []byte) bool {
return false
}
// stripThinkingFromClaudeRequest 从 Claude 请求中移除所有 thinking 相关内容
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{
"model not found",
"model does not exist",
"unknown model",
"invalid model",
}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return false
}
// stripThinkingFromClaudeRequest 从 Claude 请求中移除有问题的 thinking 块
// 策略:只移除历史消息中带 dummy signature 的 thinking 块,保留本次 thinking 配置
// 这样可以让本次对话仍然使用 thinking 功能,只是清理历史中可能导致问题的内容
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) *antigravity.ClaudeRequest {
// 创建副本
stripped := *req
// 移除 thinking 配置
stripped.Thinking = nil
// 保留 thinking 配置,让本次对话仍然可以使用 thinking
// stripped.Thinking = nil // 不再移除
// 移除消息中的 thinking 块
// 移除消息中带 dummy signature 的 thinking 块
if len(stripped.Messages) > 0 {
newMessages := make([]antigravity.ClaudeMessage, 0, len(stripped.Messages))
for _, msg := range stripped.Messages {
newMsg := msg
// 如果 content 是数组,过滤 thinking 块
// 如果 content 是数组,过滤有问题的 thinking 块
var blocks []map[string]any
if err := json.Unmarshal(msg.Content, &blocks); err == nil {
filtered := make([]map[string]any, 0, len(blocks))
for _, block := range blocks {
// 跳过有 type="thinking" 的
// 跳过带 dummy signature 的 thinking
if blockType, ok := block["type"].(string); ok && blockType == "thinking" {
continue
if sig, ok := block["signature"].(string); ok {
// 移除 dummy signature 的 thinking 块
if sig == "skip_thought_signature_validator" || sig == "" {
continue
}
} else {
// 没有 signature 字段的 thinking 块也移除
continue
}
}
// 跳过没有 type 但有 thinking 字段的块untyped thinking blocks
if _, hasType := block["type"]; !hasType {
@@ -390,9 +425,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
if mappedModel != claudeReq.Model {
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
}
// 获取 access_token
if s.tokenProvider == nil {
@@ -418,15 +450,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
@@ -495,7 +518,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
if err != nil {
log.Printf("[Antigravity] Failed to transform stripped request: %v", err)
// 降级失败,返回原始错误
if s.shouldFailoverUpstreamError(resp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
@@ -505,7 +528,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
retryReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, strippedBody)
if err != nil {
log.Printf("[Antigravity] Failed to create retry request: %v", err)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
@@ -514,7 +537,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
retryResp, err := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("[Antigravity] Retry request failed: %v", err)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
@@ -531,7 +554,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
log.Printf("[Antigravity] Retry also failed with status %d: %s", retryResp.StatusCode, string(retryRespBody))
s.handleUpstreamError(ctx, account, retryResp.StatusCode, retryResp.Header, retryRespBody)
if s.shouldFailoverUpstreamError(retryResp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, retryResp.StatusCode, retryRespBody) {
return nil, &UpstreamFailoverError{StatusCode: retryResp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, retryResp.StatusCode, retryRespBody)
@@ -540,7 +563,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// 不是 signature 错误,或者已经没有 thinking 块,直接返回错误
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -594,8 +617,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
case "generateContent", "streamGenerateContent":
// ok
case "countTokens":
return nil, s.writeGoogleError(c, http.StatusNotImplemented, "countTokens is not supported")
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
@@ -650,18 +675,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
sleepAntigravityBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
@@ -678,18 +691,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
@@ -712,20 +713,42 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
// Check if model fallback is enabled and this is a model not found error
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
// Only retry if fallback model is different from current model
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)",
mappedModel, fallbackModel, account.Name)
// Close original response
_ = resp.Body.Close()
// Rebuild request with fallback model
fallbackBody, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackBody)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
log.Printf("[Antigravity] Fallback succeeded with %s (account: %s)", fallbackModel, account.Name)
resp = fallbackResp
originalModel = fallbackModel // Update for billing
// Continue to normal response handling
goto handleSuccess
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
}
}
}
log.Printf("[Antigravity] Fallback failed, returning original error")
}
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -739,6 +762,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
handleSuccess:
var usage *ClaudeUsage
var firstTokenMs *int
@@ -789,6 +813,15 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
}
}
func (s *AntigravityGatewayService) shouldFailoverWithTempUnsched(ctx context.Context, account *Account, statusCode int, body []byte) bool {
if s.rateLimitService != nil {
if s.rateLimitService.HandleTempUnschedulable(ctx, account, statusCode, body) {
return true
}
}
return s.shouldFailoverUpstreamError(statusCode)
}
func sleepAntigravityBackoff(attempt int) {
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
}
@@ -899,7 +932,10 @@ func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Cont
}
// 解包 v1internal 响应
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
unwrapped := respBody
if inner, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && inner != nil {
unwrapped = inner
}
var parsed map[string]any
if json.Unmarshal(unwrapped, &parsed) == nil {
@@ -973,6 +1009,8 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
statusStr = "RESOURCE_EXHAUSTED"
case 500:
statusStr = "INTERNAL"
case 501:
statusStr = "UNIMPLEMENTED"
case 502, 503:
statusStr = "UNAVAILABLE"
}

View File

@@ -104,28 +104,28 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-haiku-4 → gemini-3-flash",
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5 → gemini-3-flash",
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash",
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash",
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",

View File

@@ -0,0 +1,134 @@
package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository
}
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
}
// CanFetch 检查是否可以获取此账户的额度
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
if f == nil || account == nil {
return false
}
if account.Platform != PlatformAntigravity {
return false
}
accessToken := account.GetCredential("access_token")
return accessToken != ""
}
// FetchQuota 获取 Antigravity 账户额度信息
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
if f == nil {
return nil, fmt.Errorf("antigravity quota fetcher is nil")
}
if account == nil {
return nil, fmt.Errorf("account is nil")
}
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
// 如果没有 project_id生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return nil, err
}
// 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp)
return &QuotaResult{
UsageInfo: usageInfo,
Raw: modelsRaw,
}, nil
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
}
if modelsResp == nil {
return info
}
// 遍历所有模型,填充 AntigravityQuota
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
utilization := clampInt(int((1.0-modelInfo.QuotaInfo.RemainingFraction)*100), 0, 100)
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime,
}
}
// 同时设置 FiveHour 用于兼容展示(取主要模型)
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
for _, modelName := range priorityModels {
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
utilization := clampFloat64((1.0-modelInfo.QuotaInfo.RemainingFraction)*100, 0, 100)
progress := &UsageProgress{
Utilization: utilization,
}
if modelInfo.QuotaInfo.ResetTime != "" {
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
progress.ResetsAt = &resetTime
progress.RemainingSeconds = remainingSecondsUntil(resetTime)
}
}
info.FiveHour = progress
break
}
}
return info
}
// GetProxyURL 获取账户的代理 URL
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) (string, error) {
if f == nil {
return "", fmt.Errorf("antigravity quota fetcher is nil")
}
if account == nil {
return "", fmt.Errorf("account is nil")
}
if account.ProxyID == nil || f.proxyRepo == nil {
return "", nil
}
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
if err != nil {
return "", err
}
if proxy == nil {
return "", nil
}
return proxy.URL(), nil
}

View File

@@ -1,222 +0,0 @@
package service
import (
"context"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
type AntigravityQuotaRefresher struct {
accountRepo AccountRepository
proxyRepo ProxyRepository
cfg *config.TokenRefreshConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// NewAntigravityQuotaRefresher 创建配额刷新器
func NewAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
_ *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
return &AntigravityQuotaRefresher{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
}
}
// Start 启动后台配额刷新服务
func (r *AntigravityQuotaRefresher) Start() {
if !r.cfg.Enabled {
log.Println("[AntigravityQuota] Service disabled by configuration")
return
}
r.wg.Add(1)
go r.refreshLoop()
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
}
// Stop 停止服务
func (r *AntigravityQuotaRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
log.Println("[AntigravityQuota] Service stopped")
}
// refreshLoop 刷新循环
func (r *AntigravityQuotaRefresher) refreshLoop() {
defer r.wg.Done()
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次
r.processRefresh()
for {
select {
case <-ticker.C:
r.processRefresh()
case <-r.stopCh:
return
}
}
}
// processRefresh 执行一次刷新
func (r *AntigravityQuotaRefresher) processRefresh() {
ctx := context.Background()
// 查询所有 active 的账户,然后过滤 antigravity 平台
allAccounts, err := r.accountRepo.ListActive(ctx)
if err != nil {
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
return
}
// 过滤 antigravity 平台账户
var accounts []Account
for _, acc := range allAccounts {
if acc.Platform == PlatformAntigravity {
accounts = append(accounts, acc)
}
}
if len(accounts) == 0 {
return
}
refreshed, failed := 0, 0
for i := range accounts {
account := &accounts[i]
if err := r.refreshAccountQuota(ctx, account); err != nil {
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
failed++
} else {
refreshed++
}
}
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
len(accounts), refreshed, failed)
}
// refreshAccountQuota 刷新单个账户的配额
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
if accessToken == "" {
return nil // 没有 access_token跳过
}
// token 过期则跳过,由 TokenRefreshService 负责刷新
if r.isTokenExpired(account) {
return nil
}
// 获取代理 URL
var proxyURL string
if account.ProxyID != nil {
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
client := antigravity.NewClient(proxyURL)
if account.Extra == nil {
account.Extra = make(map[string]any)
}
// 获取账户信息tier、project_id 等)
loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken)
if loadRaw != nil {
account.Extra["load_code_assist"] = loadRaw
}
if loadResp != nil {
// 尝试从 API 获取 project_id
if projectID == "" && loadResp.CloudAICompanionProject != "" {
projectID = loadResp.CloudAICompanionProject
account.Credentials["project_id"] = projectID
}
}
// 如果仍然没有 project_id随机生成一个并保存
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
account.Credentials["project_id"] = projectID
log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID)
}
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息
}
// 保存完整的配额响应
if modelsRaw != nil {
account.Extra["available_models"] = modelsRaw
}
// 解析配额数据为前端使用的格式
r.updateAccountQuota(account, modelsResp)
account.Extra["last_refresh"] = time.Now().Format(time.RFC3339)
// 保存到数据库
return r.accountRepo.Update(ctx, account)
}
// isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
// 提前 5 分钟认为过期
return time.Now().Add(5 * time.Minute).After(*expiresAt)
}
// updateAccountQuota 更新账户的配额信息(前端使用的格式)
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
quota := make(map[string]any)
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
quota[modelName] = map[string]any{
"remaining": remaining,
"reset_time": modelInfo.QuotaInfo.ResetTime,
}
}
account.Extra["quota"] = quota
}

View File

@@ -0,0 +1,21 @@
package service
import (
"context"
)
// QuotaFetcher 额度获取接口,各平台实现此接口
type QuotaFetcher interface {
// CanFetch 检查是否可以获取此账户的额度
CanFetch(account *Account) bool
// GetProxyURL 获取账户的代理 URL如果没有代理则返回空字符串
GetProxyURL(ctx context.Context, account *Account) (string, error)
// FetchQuota 获取账户额度信息
FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error)
}
// QuotaResult 额度获取结果
type QuotaResult struct {
UsageInfo *UsageInfo // 转换后的使用信息
Raw map[string]any // 原始响应,可存入 account.Extra
}