feat(antigravity): 增强网关功能和 thinking 块处理
主要改进: - 优化 thinking blocks 过滤策略,支持 Auto 模式降级 - 将无效 thinking block 内容转为普通 text - 保留单个空白 text block,不过滤 - 重构配额刷新机制,统一与 Claude 一致 - 支持 cachedContentTokenCount 映射到 cache_read_input_tokens - Haiku 模型映射到 Sonnet - 添加 /antigravity/v1/models 端点支持 - countTokens 端点直接返回空值
This commit is contained in:
@@ -49,11 +49,11 @@ var antigravityPrefixMapping = []struct {
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
|
||||
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||
{"claude-haiku-4", "gemini-3-flash"},
|
||||
{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
|
||||
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
}
|
||||
@@ -64,6 +64,7 @@ type AntigravityGatewayService struct {
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -72,12 +73,14 @@ func NewAntigravityGatewayService(
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
settingService: settingService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,6 +311,7 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
|
||||
}
|
||||
|
||||
// isSignatureRelatedError 检测是否为 signature 相关的 400 错误
|
||||
// 注意:不包含 "thinking" 关键词,避免误判消息格式错误为 signature 错误
|
||||
func isSignatureRelatedError(statusCode int, body []byte) bool {
|
||||
if statusCode != 400 {
|
||||
return false
|
||||
@@ -318,7 +322,6 @@ func isSignatureRelatedError(statusCode int, body []byte) bool {
|
||||
"signature",
|
||||
"thought_signature",
|
||||
"thoughtsignature",
|
||||
"thinking",
|
||||
"invalid signature",
|
||||
"signature validation",
|
||||
}
|
||||
@@ -331,28 +334,60 @@ func isSignatureRelatedError(statusCode int, body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// stripThinkingFromClaudeRequest 从 Claude 请求中移除所有 thinking 相关内容
|
||||
// isModelNotFoundError 检测是否为模型不存在的 404 错误
|
||||
func isModelNotFoundError(statusCode int, body []byte) bool {
|
||||
if statusCode != 404 {
|
||||
return false
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
keywords := []string{
|
||||
"model not found",
|
||||
"model does not exist",
|
||||
"unknown model",
|
||||
"invalid model",
|
||||
}
|
||||
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(bodyStr, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripThinkingFromClaudeRequest 从 Claude 请求中移除有问题的 thinking 块
|
||||
// 策略:只移除历史消息中带 dummy signature 的 thinking 块,保留本次 thinking 配置
|
||||
// 这样可以让本次对话仍然使用 thinking 功能,只是清理历史中可能导致问题的内容
|
||||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) *antigravity.ClaudeRequest {
|
||||
// 创建副本
|
||||
stripped := *req
|
||||
|
||||
// 移除 thinking 配置
|
||||
stripped.Thinking = nil
|
||||
// 保留 thinking 配置,让本次对话仍然可以使用 thinking
|
||||
// stripped.Thinking = nil // 不再移除
|
||||
|
||||
// 移除消息中的 thinking 块
|
||||
// 只移除消息中带 dummy signature 的 thinking 块
|
||||
if len(stripped.Messages) > 0 {
|
||||
newMessages := make([]antigravity.ClaudeMessage, 0, len(stripped.Messages))
|
||||
for _, msg := range stripped.Messages {
|
||||
newMsg := msg
|
||||
|
||||
// 如果 content 是数组,过滤 thinking 块
|
||||
// 如果 content 是数组,过滤有问题的 thinking 块
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(msg.Content, &blocks); err == nil {
|
||||
filtered := make([]map[string]any, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
// 跳过有 type="thinking" 的块
|
||||
// 跳过带 dummy signature 的 thinking 块
|
||||
if blockType, ok := block["type"].(string); ok && blockType == "thinking" {
|
||||
continue
|
||||
if sig, ok := block["signature"].(string); ok {
|
||||
// 移除 dummy signature 的 thinking 块
|
||||
if sig == "skip_thought_signature_validator" || sig == "" {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// 没有 signature 字段的 thinking 块也移除
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 跳过没有 type 但有 thinking 字段的块(untyped thinking blocks)
|
||||
if _, hasType := block["type"]; !hasType {
|
||||
@@ -390,9 +425,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
originalModel := claudeReq.Model
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
if mappedModel != claudeReq.Model {
|
||||
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
|
||||
}
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -418,15 +450,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// 调试:记录转换后的请求体(仅记录前 2000 字符)
|
||||
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
|
||||
truncated := string(bodyJSON)
|
||||
if len(truncated) > 2000 {
|
||||
truncated = truncated[:2000] + "..."
|
||||
}
|
||||
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
|
||||
}
|
||||
|
||||
// 构建上游 action
|
||||
action := "generateContent"
|
||||
if claudeReq.Stream {
|
||||
@@ -495,7 +518,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if err != nil {
|
||||
log.Printf("[Antigravity] Failed to transform stripped request: %v", err)
|
||||
// 降级失败,返回原始错误
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
@@ -505,7 +528,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
retryReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, strippedBody)
|
||||
if err != nil {
|
||||
log.Printf("[Antigravity] Failed to create retry request: %v", err)
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
@@ -514,7 +537,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
retryResp, err := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("[Antigravity] Retry request failed: %v", err)
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
@@ -531,7 +554,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
log.Printf("[Antigravity] Retry also failed with status %d: %s", retryResp.StatusCode, string(retryRespBody))
|
||||
s.handleUpstreamError(ctx, account, retryResp.StatusCode, retryResp.Header, retryRespBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(retryResp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, retryResp.StatusCode, retryRespBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: retryResp.StatusCode}
|
||||
}
|
||||
return nil, s.writeMappedClaudeError(c, retryResp.StatusCode, retryRespBody)
|
||||
@@ -540,7 +563,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 不是 signature 错误,或者已经没有 thinking 块,直接返回错误
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
@@ -594,8 +617,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent", "countTokens":
|
||||
case "generateContent", "streamGenerateContent":
|
||||
// ok
|
||||
case "countTokens":
|
||||
return nil, s.writeGoogleError(c, http.StatusNotImplemented, "countTokens is not supported")
|
||||
default:
|
||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||
}
|
||||
@@ -650,18 +675,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
@@ -678,18 +691,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
@@ -712,20 +713,42 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
// Check if model fallback is enabled and this is a model not found error
|
||||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||||
|
||||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||||
|
||||
// Only retry if fallback model is different from current model
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)",
|
||||
mappedModel, fallbackModel, account.Name)
|
||||
|
||||
// Close original response
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Rebuild request with fallback model
|
||||
fallbackBody, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackBody)
|
||||
if err == nil {
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
log.Printf("[Antigravity] Fallback succeeded with %s (account: %s)", fallbackModel, account.Name)
|
||||
resp = fallbackResp
|
||||
originalModel = fallbackModel // Update for billing
|
||||
// Continue to normal response handling
|
||||
goto handleSuccess
|
||||
} else if fallbackResp != nil {
|
||||
_ = fallbackResp.Body.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("[Antigravity] Fallback failed, returning original error")
|
||||
}
|
||||
}
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
if s.shouldFailoverWithTempUnsched(ctx, account, resp.StatusCode, respBody) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
@@ -739,6 +762,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
handleSuccess:
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
@@ -789,6 +813,15 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) shouldFailoverWithTempUnsched(ctx context.Context, account *Account, statusCode int, body []byte) bool {
|
||||
if s.rateLimitService != nil {
|
||||
if s.rateLimitService.HandleTempUnschedulable(ctx, account, statusCode, body) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return s.shouldFailoverUpstreamError(statusCode)
|
||||
}
|
||||
|
||||
func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
}
|
||||
@@ -899,7 +932,10 @@ func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Cont
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
unwrapped := respBody
|
||||
if inner, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && inner != nil {
|
||||
unwrapped = inner
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(unwrapped, &parsed) == nil {
|
||||
@@ -973,6 +1009,8 @@ func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int,
|
||||
statusStr = "RESOURCE_EXHAUSTED"
|
||||
case 500:
|
||||
statusStr = "INTERNAL"
|
||||
case 501:
|
||||
statusStr = "UNIMPLEMENTED"
|
||||
case 502, 503:
|
||||
statusStr = "UNAVAILABLE"
|
||||
}
|
||||
|
||||
@@ -104,28 +104,28 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4 → gemini-3-flash",
|
||||
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5 → gemini-3-flash",
|
||||
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash",
|
||||
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-3-haiku-20240307",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash",
|
||||
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5-20251001",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||
|
||||
134
backend/internal/service/antigravity_quota_fetcher.go
Normal file
134
backend/internal/service/antigravity_quota_fetcher.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
|
||||
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
|
||||
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
|
||||
}
|
||||
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
|
||||
if f == nil || account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
accessToken := account.GetCredential("access_token")
|
||||
return accessToken != ""
|
||||
}
|
||||
|
||||
// FetchQuota 获取 Antigravity 账户额度信息
|
||||
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
|
||||
if f == nil {
|
||||
return nil, fmt.Errorf("antigravity quota fetcher is nil")
|
||||
}
|
||||
if account == nil {
|
||||
return nil, fmt.Errorf("account is nil")
|
||||
}
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
Raw: modelsRaw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
}
|
||||
|
||||
if modelsResp == nil {
|
||||
return info
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
|
||||
utilization := clampInt(int((1.0-modelInfo.QuotaInfo.RemainingFraction)*100), 0, 100)
|
||||
|
||||
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
|
||||
for _, modelName := range priorityModels {
|
||||
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
|
||||
utilization := clampFloat64((1.0-modelInfo.QuotaInfo.RemainingFraction)*100, 0, 100)
|
||||
progress := &UsageProgress{
|
||||
Utilization: utilization,
|
||||
}
|
||||
if modelInfo.QuotaInfo.ResetTime != "" {
|
||||
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
|
||||
progress.ResetsAt = &resetTime
|
||||
progress.RemainingSeconds = remainingSecondsUntil(resetTime)
|
||||
}
|
||||
}
|
||||
info.FiveHour = progress
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetProxyURL 获取账户的代理 URL
|
||||
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) (string, error) {
|
||||
if f == nil {
|
||||
return "", fmt.Errorf("antigravity quota fetcher is nil")
|
||||
}
|
||||
if account == nil {
|
||||
return "", fmt.Errorf("account is nil")
|
||||
}
|
||||
if account.ProxyID == nil || f.proxyRepo == nil {
|
||||
return "", nil
|
||||
}
|
||||
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if proxy == nil {
|
||||
return "", nil
|
||||
}
|
||||
return proxy.URL(), nil
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
|
||||
type AntigravityQuotaRefresher struct {
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaRefresher 创建配额刷新器
|
||||
func NewAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
_ *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
return &AntigravityQuotaRefresher{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动后台配额刷新服务
|
||||
func (r *AntigravityQuotaRefresher) Start() {
|
||||
if !r.cfg.Enabled {
|
||||
log.Println("[AntigravityQuota] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.refreshLoop()
|
||||
|
||||
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (r *AntigravityQuotaRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
log.Println("[AntigravityQuota] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (r *AntigravityQuotaRefresher) refreshLoop() {
|
||||
defer r.wg.Done()
|
||||
|
||||
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次
|
||||
r.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.processRefresh()
|
||||
case <-r.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新
|
||||
func (r *AntigravityQuotaRefresher) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询所有 active 的账户,然后过滤 antigravity 平台
|
||||
allAccounts, err := r.accountRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤 antigravity 平台账户
|
||||
var accounts []Account
|
||||
for _, acc := range allAccounts {
|
||||
if acc.Platform == PlatformAntigravity {
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
if err := r.refreshAccountQuota(ctx, account); err != nil {
|
||||
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
refreshed++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
|
||||
len(accounts), refreshed, failed)
|
||||
}
|
||||
|
||||
// refreshAccountQuota 刷新单个账户的配额
|
||||
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
if accessToken == "" {
|
||||
return nil // 没有 access_token,跳过
|
||||
}
|
||||
|
||||
// token 过期则跳过,由 TokenRefreshService 负责刷新
|
||||
if r.isTokenExpired(account) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取代理 URL
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
|
||||
// 获取账户信息(tier、project_id 等)
|
||||
loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken)
|
||||
if loadRaw != nil {
|
||||
account.Extra["load_code_assist"] = loadRaw
|
||||
}
|
||||
if loadResp != nil {
|
||||
// 尝试从 API 获取 project_id
|
||||
if projectID == "" && loadResp.CloudAICompanionProject != "" {
|
||||
projectID = loadResp.CloudAICompanionProject
|
||||
account.Credentials["project_id"] = projectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然没有 project_id,随机生成一个并保存
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
account.Credentials["project_id"] = projectID
|
||||
log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID)
|
||||
}
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息
|
||||
}
|
||||
|
||||
// 保存完整的配额响应
|
||||
if modelsRaw != nil {
|
||||
account.Extra["available_models"] = modelsRaw
|
||||
}
|
||||
|
||||
// 解析配额数据为前端使用的格式
|
||||
r.updateAccountQuota(account, modelsResp)
|
||||
|
||||
account.Extra["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// 保存到数据库
|
||||
return r.accountRepo.Update(ctx, account)
|
||||
}
|
||||
|
||||
// isTokenExpired 检查 token 是否过期
|
||||
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 提前 5 分钟认为过期
|
||||
return time.Now().Add(5 * time.Minute).After(*expiresAt)
|
||||
}
|
||||
|
||||
// updateAccountQuota 更新账户的配额信息(前端使用的格式)
|
||||
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
|
||||
quota := make(map[string]any)
|
||||
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
|
||||
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
|
||||
|
||||
quota[modelName] = map[string]any{
|
||||
"remaining": remaining,
|
||||
"reset_time": modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
account.Extra["quota"] = quota
|
||||
}
|
||||
21
backend/internal/service/quota_fetcher.go
Normal file
21
backend/internal/service/quota_fetcher.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// QuotaFetcher 额度获取接口,各平台实现此接口
|
||||
type QuotaFetcher interface {
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
CanFetch(account *Account) bool
|
||||
// GetProxyURL 获取账户的代理 URL(如果没有代理则返回空字符串)
|
||||
GetProxyURL(ctx context.Context, account *Account) (string, error)
|
||||
// FetchQuota 获取账户额度信息
|
||||
FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error)
|
||||
}
|
||||
|
||||
// QuotaResult 额度获取结果
|
||||
type QuotaResult struct {
|
||||
UsageInfo *UsageInfo // 转换后的使用信息
|
||||
Raw map[string]any // 原始响应,可存入 account.Extra
|
||||
}
|
||||
Reference in New Issue
Block a user