Merge branch 'main' into test-dev

This commit is contained in:
yangjianbo
2026-01-03 21:38:21 +08:00
21 changed files with 408 additions and 429 deletions

View File

@@ -70,7 +70,6 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
antigravityQuota *service.AntigravityQuotaRefresher,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -113,10 +112,6 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
{"AntigravityQuotaRefresher", func() error {
antigravityQuota.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},

View File

@@ -90,7 +90,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
@@ -145,8 +147,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -179,7 +180,6 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
antigravityQuota *service.AntigravityQuotaRefresher,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -221,10 +221,6 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
{"AntigravityQuotaRefresher", func() error {
antigravityQuota.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},

View File

@@ -588,8 +588,20 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
// Send error event in SSE format with proper JSON marshaling
errorData := map[string]any{
"type": "error",
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -740,8 +752,27 @@ func sendMockWarmupStream(c *gin.Context, model string) {
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_mock_warmup",
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
events := []string{
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,

View File

@@ -144,6 +144,21 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// Try immediate acquire first (avoid unnecessary wait)
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Determine if ping is needed (streaming + ping format defined)
needPing := isStream && h.pingFormat != ""

View File

@@ -143,9 +143,10 @@ type GeminiCandidate struct {
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
// DefaultSafetySettings 默认安全设置(关闭所有过滤)

View File

@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil {
@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig
reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
generationConfig := buildGenerationConfig(claudeReq)
// 4. 构建 tools
tools := buildTools(claudeReq.Tools)
@@ -183,34 +172,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -239,30 +200,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
case "thinking":
if allowDummyThought {
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
part := GeminiPart{
Text: block.Thinking,
Thought: true,
}
// 保留原有 signatureClaude 模型需要有效的 signature
if block.Signature != "" {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
}
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
parts = append(parts, part)
case "image":
if block.Source != nil && block.Source.Type == "base64" {
@@ -433,7 +386,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
var funcDecls []GeminiFunctionDecl
for i, tool := range tools {
for _, tool := range tools {
// 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name")
@@ -452,10 +405,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else {
// 标准格式: 从顶层字段获取
description = tool.Description
@@ -472,11 +421,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name,
Description: description,
@@ -631,11 +575,9 @@ func cleanSchemaValue(value any) any {
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}

View File

@@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "mcp_tool",
Custom: &CustomToolSpec{
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
@@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "custom_tool",
Custom: &CustomToolSpec{
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
@@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "invalid_custom",
Custom: &CustomToolSpec{
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},

View File

@@ -232,10 +232,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
stopReason = "max_tokens"
}
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
usage := ClaudeUsage{}
if geminiResp.UsageMetadata != nil {
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount
cached := geminiResp.UsageMetadata.CachedContentTokenCount
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
// 生成响应 ID

View File

@@ -29,8 +29,9 @@ type StreamingProcessor struct {
originalModel string
// 累计 usage
inputTokens int
outputTokens int
inputTokens int
outputTokens int
cacheReadTokens int
}
// NewStreamingProcessor 创建流式响应处理器
@@ -76,9 +77,13 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
}
// 更新 usage
// 注意Gemini 的 promptTokenCount 包含 cachedContentTokenCount
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens需要减去
if geminiResp.UsageMetadata != nil {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount
cached := geminiResp.UsageMetadata.CachedContentTokenCount
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
p.cacheReadTokens = cached
}
// 处理 parts
@@ -108,8 +113,9 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
}
usage := &ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
return result.Bytes(), usage
@@ -123,8 +129,10 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage := ClaudeUsage{}
if v1Resp.Response.UsageMetadata != nil {
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
usage.CacheReadInputTokens = cached
}
responseID := v1Resp.ResponseID
@@ -418,8 +426,9 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
}
usage := ClaudeUsage{
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
}
deltaEvent := map[string]any{

View File

@@ -151,11 +151,17 @@ var (
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
@@ -163,6 +169,9 @@ var (
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID

View File

@@ -69,13 +69,29 @@ type windowStatsCache struct {
timestamp time.Time
}
var (
apiCacheMap = sync.Map{} // 缓存 API 响应
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
// antigravityUsageCache 缓存 Antigravity 额度数据
type antigravityUsageCache struct {
usageInfo *UsageInfo
timestamp time.Time
}
const (
apiCacheTTL = 10 * time.Minute
windowStatsCacheTTL = 1 * time.Minute
)
// UsageCache 封装账户使用量相关的缓存
type UsageCache struct {
apiCache sync.Map // accountID -> *apiUsageCache
windowStatsCache sync.Map // accountID -> *windowStatsCache
antigravityCache sync.Map // accountID -> *antigravityUsageCache
}
// NewUsageCache 创建 UsageCache 实例
func NewUsageCache() *UsageCache {
return &UsageCache{}
}
// WindowStats 窗口期统计
type WindowStats struct {
Requests int64 `json:"requests"`
@@ -91,6 +107,12 @@ type UsageProgress struct {
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// AntigravityModelQuota Antigravity 单个模型的配额信息
type AntigravityModelQuota struct {
Utilization int `json:"utilization"` // 使用率 0-100
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
@@ -99,6 +121,9 @@ type UsageInfo struct {
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -124,19 +149,30 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
antigravityQuotaFetcher *AntigravityQuotaFetcher
cache *UsageCache
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
func NewAccountUsageService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageFetcher ClaudeUsageFetcher,
geminiQuotaService *GeminiQuotaService,
antigravityQuotaFetcher *AntigravityQuotaFetcher,
cache *UsageCache,
) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
antigravityQuotaFetcher: antigravityQuotaFetcher,
cache: cache,
}
}
@@ -154,12 +190,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return s.getGeminiUsage(ctx, account)
}
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
if account.Platform == PlatformAntigravity {
return s.getAntigravityUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
// 1. 检查 API 缓存10 分钟)
if cached, ok := apiCacheMap.Load(accountID); ok {
if cached, ok := s.cache.apiCache.Load(accountID); ok {
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
apiResp = cache.response
}
@@ -172,7 +213,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, err
}
// 缓存 API 响应
apiCacheMap.Store(accountID, &apiUsageCache{
s.cache.apiCache.Store(accountID, &apiUsageCache{
response: apiResp,
timestamp: time.Now(),
})
@@ -230,6 +271,43 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
return usage, nil
}
// getAntigravityUsage 获取 Antigravity 账户额度
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
now := time.Now()
return &UsageInfo{UpdatedAt: &now}, nil
}
// 1. 检查缓存10 分钟)
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
// 重新计算 RemainingSeconds
usage := cache.usageInfo
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
}
return usage, nil
}
}
// 2. 获取代理 URL
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
// 3. 调用 API 获取额度
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
if err != nil {
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
}
// 4. 缓存结果
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
usageInfo: result.UsageInfo,
timestamp: time.Now(),
})
return result.UsageInfo, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -241,7 +319,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
// 检查窗口统计缓存1 分钟)
var windowStats *WindowStats
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
windowStats = cache.stats
}
@@ -269,7 +347,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}
// 缓存窗口统计1 分钟)
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
stats: windowStats,
timestamp: time.Now(),
})

View File

@@ -322,9 +322,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
if mappedModel != claudeReq.Model {
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
}
// 获取 access_token
if s.tokenProvider == nil {
@@ -350,15 +347,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {

View File

@@ -0,0 +1,111 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository
}
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
}
// CanFetch 检查是否可以获取此账户的额度
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
if account.Platform != PlatformAntigravity {
return false
}
accessToken := account.GetCredential("access_token")
return accessToken != ""
}
// FetchQuota 获取 Antigravity 账户额度信息
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
// 如果没有 project_id生成一个随机的
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
}
client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return nil, err
}
// 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp)
return &QuotaResult{
UsageInfo: usageInfo,
Raw: modelsRaw,
}, nil
}
// buildUsageInfo 将 API 响应转换为 UsageInfo
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
}
// 遍历所有模型,填充 AntigravityQuota
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100)
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime,
}
}
// 同时设置 FiveHour 用于兼容展示(取主要模型)
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
for _, modelName := range priorityModels {
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100
progress := &UsageProgress{
Utilization: utilization,
}
if modelInfo.QuotaInfo.ResetTime != "" {
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
progress.ResetsAt = &resetTime
progress.RemainingSeconds = int(time.Until(resetTime).Seconds())
}
}
info.FiveHour = progress
break
}
}
return info
}
// GetProxyURL 获取账户的代理 URL
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string {
if account.ProxyID == nil || f.proxyRepo == nil {
return ""
}
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
if err != nil || proxy == nil {
return ""
}
return proxy.URL()
}

View File

@@ -1,222 +0,0 @@
package service
import (
"context"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
type AntigravityQuotaRefresher struct {
accountRepo AccountRepository
proxyRepo ProxyRepository
cfg *config.TokenRefreshConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// NewAntigravityQuotaRefresher 创建配额刷新器
func NewAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
_ *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
return &AntigravityQuotaRefresher{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
}
}
// Start 启动后台配额刷新服务
func (r *AntigravityQuotaRefresher) Start() {
if !r.cfg.Enabled {
log.Println("[AntigravityQuota] Service disabled by configuration")
return
}
r.wg.Add(1)
go r.refreshLoop()
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
}
// Stop 停止服务
func (r *AntigravityQuotaRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
log.Println("[AntigravityQuota] Service stopped")
}
// refreshLoop 刷新循环
func (r *AntigravityQuotaRefresher) refreshLoop() {
defer r.wg.Done()
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次
r.processRefresh()
for {
select {
case <-ticker.C:
r.processRefresh()
case <-r.stopCh:
return
}
}
}
// processRefresh 执行一次刷新
func (r *AntigravityQuotaRefresher) processRefresh() {
ctx := context.Background()
// 查询所有 active 的账户,然后过滤 antigravity 平台
allAccounts, err := r.accountRepo.ListActive(ctx)
if err != nil {
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
return
}
// 过滤 antigravity 平台账户
var accounts []Account
for _, acc := range allAccounts {
if acc.Platform == PlatformAntigravity {
accounts = append(accounts, acc)
}
}
if len(accounts) == 0 {
return
}
refreshed, failed := 0, 0
for i := range accounts {
account := &accounts[i]
if err := r.refreshAccountQuota(ctx, account); err != nil {
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
failed++
} else {
refreshed++
}
}
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
len(accounts), refreshed, failed)
}
// refreshAccountQuota 刷新单个账户的配额
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
if accessToken == "" {
return nil // 没有 access_token跳过
}
// token 过期则跳过,由 TokenRefreshService 负责刷新
if r.isTokenExpired(account) {
return nil
}
// 获取代理 URL
var proxyURL string
if account.ProxyID != nil {
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
client := antigravity.NewClient(proxyURL)
if account.Extra == nil {
account.Extra = make(map[string]any)
}
// 获取账户信息tier、project_id 等)
loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken)
if loadRaw != nil {
account.Extra["load_code_assist"] = loadRaw
}
if loadResp != nil {
// 尝试从 API 获取 project_id
if projectID == "" && loadResp.CloudAICompanionProject != "" {
projectID = loadResp.CloudAICompanionProject
account.Credentials["project_id"] = projectID
}
}
// 如果仍然没有 project_id随机生成一个并保存
if projectID == "" {
projectID = antigravity.GenerateMockProjectID()
account.Credentials["project_id"] = projectID
log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID)
}
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息
}
// 保存完整的配额响应
if modelsRaw != nil {
account.Extra["available_models"] = modelsRaw
}
// 解析配额数据为前端使用的格式
r.updateAccountQuota(account, modelsResp)
account.Extra["last_refresh"] = time.Now().Format(time.RFC3339)
// 保存到数据库
return r.accountRepo.Update(ctx, account)
}
// isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
// 提前 5 分钟认为过期
return time.Now().Add(5 * time.Minute).After(*expiresAt)
}
// updateAccountQuota 更新账户的配额信息(前端使用的格式)
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
quota := make(map[string]any)
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
quota[modelName] = map[string]any{
"remaining": remaining,
"reset_time": modelInfo.QuotaInfo.ResetTime,
}
}
account.Extra["quota"] = quota
}

View File

@@ -206,7 +206,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
@@ -431,7 +431,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)

View File

@@ -0,0 +1,19 @@
package service
import (
"context"
)
// QuotaFetcher 额度获取接口,各平台实现此接口
type QuotaFetcher interface {
// CanFetch 检查是否可以获取此账户的额度
CanFetch(account *Account) bool
// FetchQuota 获取账户额度信息
FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error)
}
// QuotaResult 额度获取结果
type QuotaResult struct {
UsageInfo *UsageInfo // 转换后的使用信息
Raw map[string]any // 原始响应,可存入 account.Extra
}

View File

@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
return svc
}
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
func ProvideAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
oauthSvc *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
svc.Start()
return svc
}
// ProvideDeferredService creates and starts DeferredService
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
@@ -124,6 +112,7 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService,
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
NewAntigravityQuotaFetcher,
NewUserAttributeService,
NewUsageCache,
)