Merge branch 'main' into test-dev
This commit is contained in:
@@ -70,7 +70,6 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
antigravityQuota *service.AntigravityQuotaRefresher,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -113,10 +112,6 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AntigravityQuotaRefresher", func() error {
|
||||
antigravityQuota.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -90,7 +90,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
|
||||
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
|
||||
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
@@ -145,8 +147,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
|
||||
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
|
||||
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
|
||||
application := &Application{
|
||||
Server: httpServer,
|
||||
Cleanup: v,
|
||||
@@ -179,7 +180,6 @@ func provideCleanup(
|
||||
openaiOAuth *service.OpenAIOAuthService,
|
||||
geminiOAuth *service.GeminiOAuthService,
|
||||
antigravityOAuth *service.AntigravityOAuthService,
|
||||
antigravityQuota *service.AntigravityQuotaRefresher,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -221,10 +221,6 @@ func provideCleanup(
|
||||
antigravityOAuth.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"AntigravityQuotaRefresher", func() error {
|
||||
antigravityQuota.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"Redis", func() error {
|
||||
return rdb.Close()
|
||||
}},
|
||||
|
||||
@@ -588,8 +588,20 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in SSE format
|
||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
// Send error event in SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -740,8 +752,27 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_mock_warmup",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
|
||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
|
||||
@@ -144,6 +144,21 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
|
||||
@@ -143,9 +143,10 @@ type GeminiCandidate struct {
|
||||
|
||||
// GeminiUsageMetadata Gemini 用量元数据
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultSafetySettings 默认安全设置(关闭所有过滤)
|
||||
|
||||
@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
||||
// 用于存储 tool_use id -> name 映射
|
||||
toolIDToName := make(map[string]string)
|
||||
|
||||
// 检测是否启用 thinking
|
||||
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
|
||||
// 只有 Gemini 模型支持 dummy thought workaround
|
||||
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
|
||||
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
|
||||
|
||||
// 检测是否启用 thinking
|
||||
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
||||
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
|
||||
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
|
||||
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
|
||||
|
||||
// 1. 构建 contents
|
||||
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
|
||||
if err != nil {
|
||||
@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
|
||||
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
|
||||
|
||||
// 3. 构建 generationConfig
|
||||
reqForGen := claudeReq
|
||||
if requestedThinkingEnabled && !allowDummyThought {
|
||||
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
|
||||
// shallow copy to avoid mutating caller's request
|
||||
clone := *claudeReq
|
||||
clone.Thinking = nil
|
||||
reqForGen = &clone
|
||||
}
|
||||
generationConfig := buildGenerationConfig(reqForGen)
|
||||
generationConfig := buildGenerationConfig(claudeReq)
|
||||
|
||||
// 4. 构建 tools
|
||||
tools := buildTools(claudeReq.Tools)
|
||||
@@ -183,34 +172,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
|
||||
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const dummyThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
// isValidThoughtSignature 验证 thought signature 是否有效
|
||||
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
|
||||
func isValidThoughtSignature(signature string) bool {
|
||||
// 空字符串无效
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
|
||||
// 参考 Claude API 文档和实际观察到的有效 signature
|
||||
if len(signature) < 40 {
|
||||
log.Printf("[Debug] Signature too short: len=%d", len(signature))
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否是有效的 base64 字符
|
||||
// base64 字符集: A-Z, a-z, 0-9, +, /, =
|
||||
for i, c := range signature {
|
||||
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
|
||||
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
|
||||
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// buildParts 构建消息的 parts
|
||||
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
|
||||
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
|
||||
@@ -239,30 +200,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
|
||||
}
|
||||
|
||||
case "thinking":
|
||||
if allowDummyThought {
|
||||
// Gemini 模型可以使用 dummy signature
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
ThoughtSignature: dummyThoughtSignature,
|
||||
})
|
||||
part := GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
}
|
||||
// 保留原有 signature(Claude 模型需要有效的 signature)
|
||||
if block.Signature != "" {
|
||||
part.ThoughtSignature = block.Signature
|
||||
} else if !allowDummyThought {
|
||||
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
|
||||
log.Printf("Warning: skipping thinking block without signature for Claude model")
|
||||
continue
|
||||
} else {
|
||||
// Gemini 模型使用 dummy signature
|
||||
part.ThoughtSignature = dummyThoughtSignature
|
||||
}
|
||||
|
||||
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
|
||||
signature := strings.TrimSpace(block.Signature)
|
||||
if signature == "" || signature == dummyThoughtSignature {
|
||||
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
|
||||
continue
|
||||
}
|
||||
if !isValidThoughtSignature(signature) {
|
||||
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: block.Thinking,
|
||||
Thought: true,
|
||||
ThoughtSignature: signature,
|
||||
})
|
||||
parts = append(parts, part)
|
||||
|
||||
case "image":
|
||||
if block.Source != nil && block.Source.Type == "base64" {
|
||||
@@ -433,7 +386,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
|
||||
// 普通工具
|
||||
var funcDecls []GeminiFunctionDecl
|
||||
for i, tool := range tools {
|
||||
for _, tool := range tools {
|
||||
// 跳过无效工具名称
|
||||
if strings.TrimSpace(tool.Name) == "" {
|
||||
log.Printf("Warning: skipping tool with empty name")
|
||||
@@ -452,10 +405,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
description = tool.Custom.Description
|
||||
inputSchema = tool.Custom.InputSchema
|
||||
|
||||
// 调试日志:记录 custom 工具的 schema
|
||||
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
|
||||
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
|
||||
}
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
description = tool.Description
|
||||
@@ -472,11 +421,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
|
||||
}
|
||||
}
|
||||
|
||||
// 调试日志:记录清理后的 schema
|
||||
if paramsJSON, err := json.Marshal(params); err == nil {
|
||||
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
|
||||
}
|
||||
|
||||
funcDecls = append(funcDecls, GeminiFunctionDecl{
|
||||
Name: tool.Name,
|
||||
Description: description,
|
||||
@@ -631,11 +575,9 @@ func cleanSchemaValue(value any) any {
|
||||
if k == "additionalProperties" {
|
||||
if boolVal, ok := val.(bool); ok {
|
||||
result[k] = boolVal
|
||||
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
|
||||
} else {
|
||||
// 如果是 schema 对象,转换为 false(更安全的默认值)
|
||||
result[k] = false
|
||||
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "mcp_tool",
|
||||
Custom: &CustomToolSpec{
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "MCP tool description",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
@@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "custom_tool",
|
||||
Custom: &CustomToolSpec{
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Custom tool",
|
||||
InputSchema: map[string]any{"type": "object"},
|
||||
},
|
||||
@@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
||||
{
|
||||
Type: "custom",
|
||||
Name: "invalid_custom",
|
||||
Custom: &CustomToolSpec{
|
||||
Custom: &ClaudeCustomToolSpec{
|
||||
Description: "Invalid",
|
||||
// InputSchema 为 nil
|
||||
},
|
||||
|
||||
@@ -232,10 +232,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
|
||||
stopReason = "max_tokens"
|
||||
}
|
||||
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
usage := ClaudeUsage{}
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
// 生成响应 ID
|
||||
|
||||
@@ -29,8 +29,9 @@ type StreamingProcessor struct {
|
||||
originalModel string
|
||||
|
||||
// 累计 usage
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheReadTokens int
|
||||
}
|
||||
|
||||
// NewStreamingProcessor 创建流式响应处理器
|
||||
@@ -76,9 +77,13 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
}
|
||||
|
||||
// 更新 usage
|
||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||
if geminiResp.UsageMetadata != nil {
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount
|
||||
cached := geminiResp.UsageMetadata.CachedContentTokenCount
|
||||
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
|
||||
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount
|
||||
p.cacheReadTokens = cached
|
||||
}
|
||||
|
||||
// 处理 parts
|
||||
@@ -108,8 +113,9 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
return result.Bytes(), usage
|
||||
@@ -123,8 +129,10 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
||||
|
||||
usage := ClaudeUsage{}
|
||||
if v1Resp.Response.UsageMetadata != nil {
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount
|
||||
cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount
|
||||
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
|
||||
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount
|
||||
usage.CacheReadInputTokens = cached
|
||||
}
|
||||
|
||||
responseID := v1Resp.ResponseID
|
||||
@@ -418,8 +426,9 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
}
|
||||
|
||||
usage := ClaudeUsage{
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
InputTokens: p.inputTokens,
|
||||
OutputTokens: p.outputTokens,
|
||||
CacheReadInputTokens: p.cacheReadTokens,
|
||||
}
|
||||
|
||||
deltaEvent := map[string]any{
|
||||
|
||||
@@ -151,11 +151,17 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getAccountsLoadBatchScript - batch load query (read-only)
|
||||
// ARGV[1] = slot TTL (seconds, retained for compatibility)
|
||||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||
getAccountsLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
@@ -163,6 +169,9 @@ var (
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:account:' .. accountID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'wait:account:' .. accountID
|
||||
|
||||
@@ -69,13 +69,29 @@ type windowStatsCache struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
apiCacheMap = sync.Map{} // 缓存 API 响应
|
||||
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
|
||||
// antigravityUsageCache 缓存 Antigravity 额度数据
|
||||
type antigravityUsageCache struct {
|
||||
usageInfo *UsageInfo
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 10 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
func NewUsageCache() *UsageCache {
|
||||
return &UsageCache{}
|
||||
}
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
type WindowStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
@@ -91,6 +107,12 @@ type UsageProgress struct {
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
}
|
||||
|
||||
// AntigravityModelQuota Antigravity 单个模型的配额信息
|
||||
type AntigravityModelQuota struct {
|
||||
Utilization int `json:"utilization"` // 使用率 0-100
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
@@ -99,6 +121,9 @@ type UsageInfo struct {
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -124,19 +149,30 @@ type ClaudeUsageFetcher interface {
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
|
||||
func NewAccountUsageService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageFetcher ClaudeUsageFetcher,
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageFetcher: usageFetcher,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageFetcher: usageFetcher,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,12 +190,17 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return s.getGeminiUsage(ctx, account)
|
||||
}
|
||||
|
||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.getAntigravityUsage(ctx, account)
|
||||
}
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
if account.CanGetUsage() {
|
||||
var apiResp *ClaudeUsageResponse
|
||||
|
||||
// 1. 检查 API 缓存(10 分钟)
|
||||
if cached, ok := apiCacheMap.Load(accountID); ok {
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
}
|
||||
@@ -172,7 +213,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, err
|
||||
}
|
||||
// 缓存 API 响应
|
||||
apiCacheMap.Store(accountID, &apiUsageCache{
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: apiResp,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
@@ -230,6 +271,43 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// getAntigravityUsage 获取 Antigravity 账户额度
|
||||
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
@@ -241,7 +319,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
|
||||
// 检查窗口统计缓存(1 分钟)
|
||||
var windowStats *WindowStats
|
||||
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
|
||||
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
|
||||
windowStats = cache.stats
|
||||
}
|
||||
@@ -269,7 +347,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
// 缓存窗口统计(1 分钟)
|
||||
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
|
||||
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
|
||||
stats: windowStats,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
|
||||
@@ -322,9 +322,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
originalModel := claudeReq.Model
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
if mappedModel != claudeReq.Model {
|
||||
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
|
||||
}
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -350,15 +347,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// 调试:记录转换后的请求体(仅记录前 2000 字符)
|
||||
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
|
||||
truncated := string(bodyJSON)
|
||||
if len(truncated) > 2000 {
|
||||
truncated = truncated[:2000] + "..."
|
||||
}
|
||||
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
|
||||
}
|
||||
|
||||
// 构建上游 action
|
||||
action := "generateContent"
|
||||
if claudeReq.Stream {
|
||||
|
||||
111
backend/internal/service/antigravity_quota_fetcher.go
Normal file
111
backend/internal/service/antigravity_quota_fetcher.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
|
||||
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
|
||||
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
|
||||
}
|
||||
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
accessToken := account.GetCredential("access_token")
|
||||
return accessToken != ""
|
||||
}
|
||||
|
||||
// FetchQuota 获取 Antigravity 账户额度信息
|
||||
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
Raw: modelsRaw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
|
||||
utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100)
|
||||
|
||||
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
|
||||
for _, modelName := range priorityModels {
|
||||
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
|
||||
utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100
|
||||
progress := &UsageProgress{
|
||||
Utilization: utilization,
|
||||
}
|
||||
if modelInfo.QuotaInfo.ResetTime != "" {
|
||||
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
|
||||
progress.ResetsAt = &resetTime
|
||||
progress.RemainingSeconds = int(time.Until(resetTime).Seconds())
|
||||
}
|
||||
}
|
||||
info.FiveHour = progress
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetProxyURL 获取账户的代理 URL
|
||||
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string {
|
||||
if account.ProxyID == nil || f.proxyRepo == nil {
|
||||
return ""
|
||||
}
|
||||
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err != nil || proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
|
||||
type AntigravityQuotaRefresher struct {
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaRefresher 创建配额刷新器
|
||||
func NewAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
_ *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
return &AntigravityQuotaRefresher{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动后台配额刷新服务
|
||||
func (r *AntigravityQuotaRefresher) Start() {
|
||||
if !r.cfg.Enabled {
|
||||
log.Println("[AntigravityQuota] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.refreshLoop()
|
||||
|
||||
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (r *AntigravityQuotaRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
log.Println("[AntigravityQuota] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (r *AntigravityQuotaRefresher) refreshLoop() {
|
||||
defer r.wg.Done()
|
||||
|
||||
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次
|
||||
r.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.processRefresh()
|
||||
case <-r.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新
|
||||
func (r *AntigravityQuotaRefresher) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询所有 active 的账户,然后过滤 antigravity 平台
|
||||
allAccounts, err := r.accountRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤 antigravity 平台账户
|
||||
var accounts []Account
|
||||
for _, acc := range allAccounts {
|
||||
if acc.Platform == PlatformAntigravity {
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
if err := r.refreshAccountQuota(ctx, account); err != nil {
|
||||
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
refreshed++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
|
||||
len(accounts), refreshed, failed)
|
||||
}
|
||||
|
||||
// refreshAccountQuota 刷新单个账户的配额
|
||||
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
if accessToken == "" {
|
||||
return nil // 没有 access_token,跳过
|
||||
}
|
||||
|
||||
// token 过期则跳过,由 TokenRefreshService 负责刷新
|
||||
if r.isTokenExpired(account) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取代理 URL
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
|
||||
// 获取账户信息(tier、project_id 等)
|
||||
loadResp, loadRaw, _ := client.LoadCodeAssist(ctx, accessToken)
|
||||
if loadRaw != nil {
|
||||
account.Extra["load_code_assist"] = loadRaw
|
||||
}
|
||||
if loadResp != nil {
|
||||
// 尝试从 API 获取 project_id
|
||||
if projectID == "" && loadResp.CloudAICompanionProject != "" {
|
||||
projectID = loadResp.CloudAICompanionProject
|
||||
account.Credentials["project_id"] = projectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然没有 project_id,随机生成一个并保存
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
account.Credentials["project_id"] = projectID
|
||||
log.Printf("[AntigravityQuotaRefresher] 为账户 %d 生成随机 project_id: %s", account.ID, projectID)
|
||||
}
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return r.accountRepo.Update(ctx, account) // 保存已有的 load_code_assist 信息
|
||||
}
|
||||
|
||||
// 保存完整的配额响应
|
||||
if modelsRaw != nil {
|
||||
account.Extra["available_models"] = modelsRaw
|
||||
}
|
||||
|
||||
// 解析配额数据为前端使用的格式
|
||||
r.updateAccountQuota(account, modelsResp)
|
||||
|
||||
account.Extra["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
// 保存到数据库
|
||||
return r.accountRepo.Update(ctx, account)
|
||||
}
|
||||
|
||||
// isTokenExpired 检查 token 是否过期
|
||||
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 提前 5 分钟认为过期
|
||||
return time.Now().Add(5 * time.Minute).After(*expiresAt)
|
||||
}
|
||||
|
||||
// updateAccountQuota 更新账户的配额信息(前端使用的格式)
|
||||
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
|
||||
quota := make(map[string]any)
|
||||
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
|
||||
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
|
||||
|
||||
quota[modelName] = map[string]any{
|
||||
"remaining": remaining,
|
||||
"reset_time": modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
account.Extra["quota"] = quota
|
||||
}
|
||||
@@ -206,7 +206,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||
@@ -431,7 +431,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
|
||||
19
backend/internal/service/quota_fetcher.go
Normal file
19
backend/internal/service/quota_fetcher.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// QuotaFetcher 额度获取接口,各平台实现此接口
|
||||
type QuotaFetcher interface {
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
CanFetch(account *Account) bool
|
||||
// FetchQuota 获取账户额度信息
|
||||
FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error)
|
||||
}
|
||||
|
||||
// QuotaResult 额度获取结果
|
||||
type QuotaResult struct {
|
||||
UsageInfo *UsageInfo // 转换后的使用信息
|
||||
Raw map[string]any // 原始响应,可存入 account.Extra
|
||||
}
|
||||
@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
|
||||
func ProvideAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
oauthSvc *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideDeferredService creates and starts DeferredService
|
||||
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
|
||||
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
|
||||
@@ -124,6 +112,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideTokenRefreshService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDeferredService,
|
||||
ProvideAntigravityQuotaRefresher,
|
||||
NewAntigravityQuotaFetcher,
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user