Merge branch 'main' into test-dev

This commit is contained in:
yangjianbo
2025-12-30 09:07:55 +08:00
61 changed files with 7785 additions and 262 deletions

View File

@@ -346,3 +346,20 @@ func (a *Account) IsOpenAITokenExpired() bool {
}
return time.Now().Add(60 * time.Second).After(*expiresAt)
}
// IsMixedSchedulingEnabled 检查 antigravity 账户是否启用混合调度
// 启用后可参与 anthropic/gemini 分组的账户调度
func (a *Account) IsMixedSchedulingEnabled() bool {
if a.Platform != PlatformAntigravity {
return false
}
if a.Extra == nil {
return false
}
if v, ok := a.Extra["mixed_scheduling"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}

View File

@@ -40,6 +40,8 @@ type AccountRepository interface {
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error

View File

@@ -0,0 +1,823 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
antigravityStickySessionTTL = time.Hour
antigravityMaxRetries = 5
antigravityRetryBaseDelay = 1 * time.Second
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
"claude-sonnet-4-5-thinking": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
"gemini-2.5-flash-thinking": true,
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "gemini-3-flash",
"claude-haiku-4-5": "gemini-3-flash",
"claude-3-haiku-20240307": "gemini-3-flash",
"claude-haiku-4-5-20251001": "gemini-3-flash",
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview": "gemini-3-pro-image",
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
type AntigravityGatewayService struct {
accountRepo AccountRepository
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
}
func NewAntigravityGatewayService(
accountRepo AccountRepository,
_ GatewayCache,
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
}
}
// GetTokenProvider 返回 token provider
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
return s.tokenProvider
}
// getMappedModel 获取映射后的模型名
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法)
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 5. 默认值
return "claude-sonnet-4-5"
}
// IsModelSupported 检查模型是否被支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
}
// wrapV1InternalRequest 包装请求为 v1internal 格式
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
var request any
if err := json.Unmarshal(originalBody, &request); err != nil {
return nil, fmt.Errorf("解析请求体失败: %w", err)
}
wrapped := map[string]any{
"project": projectID,
"requestId": "agent-" + uuid.New().String(),
"userAgent": "sub2api",
"requestType": "agent",
"model": model,
"request": request,
}
return json.Marshal(wrapped)
}
// unwrapV1InternalResponse 解包 v1internal 响应
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
var outer map[string]any
if err := json.Unmarshal(body, &outer); err != nil {
return nil, err
}
if resp, ok := outer["response"]; ok {
return json.Marshal(resp)
}
return body, nil
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
// 解析 Claude 请求
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
mappedModel := s.getMappedModel(account, claudeReq.Model)
if mappedModel != claudeReq.Model {
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
}
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 转换 Claude 请求为 Gemini 格式
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel)
if err != nil {
return nil, fmt.Errorf("transform request: %w", err)
}
// 构建上游 URL
action := "generateContent"
if claudeReq.Stream {
action = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
if claudeReq.Stream {
fullURL += "?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody))
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
continue
}
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
// 最后一次尝试也失败
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
}
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usage, err = s.handleClaudeNonStreamingResponse(c, resp, originalModel)
if err != nil {
return nil, err
}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: claudeReq.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
}
if strings.TrimSpace(action) == "" {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
}
if len(body) == 0 {
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
// ok
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
mappedModel := s.getMappedModel(account, originalModel)
// 获取 access_token
if s.tokenProvider == nil {
return nil, errors.New("antigravity token provider not configured")
}
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
}
// 获取 project_id
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID == "" {
return nil, errors.New("project_id not found in credentials")
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
if err != nil {
return nil, err
}
// 构建上游 URL
upstreamAction := action
if action == "generateContent" && stream {
upstreamAction = "streamGenerateContent"
}
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
if stream || upstreamAction == "streamGenerateContent" {
fullURL += "?alt=sse"
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
if err != nil {
return nil, err
}
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
sleepAntigravityBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
sleepAntigravityBackoff(attempt)
continue
}
// 所有重试都失败,标记限流状态
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(respBody)),
}
break
}
break
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 解包并返回错误
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, unwrapped)
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
var usage *ClaudeUsage
var firstTokenMs *int
if stream || upstreamAction == "streamGenerateContent" {
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
if err != nil {
return nil, err
}
usage = streamRes.usage
firstTokenMs = streamRes.firstTokenMs
} else {
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
if err != nil {
return nil, err
}
usage = usageResp
}
if usage == nil {
usage = &ClaudeUsage{}
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
Model: originalModel,
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
default:
return false
}
}
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func sleepAntigravityBackoff(attempt int) {
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
}
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if statusCode == 429 {
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
// 解析失败Gemini 有重试时间用 5 分钟Claude 没有用 1 分钟
defaultDur := 1 * time.Minute
if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
defaultDur = 5 * time.Minute
}
ra := time.Now().Add(defaultDur)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
return
}
// 其他错误码继续使用 rateLimitService
if s.rateLimitService == nil {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
}
type antigravityStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream; charset=utf-8"
}
c.Header("Content-Type", contentType)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
reader := bufio.NewReader(resp.Body)
usage := &ClaudeUsage{}
var firstTokenMs *int
for {
line, err := reader.ReadString('\n')
if len(line) > 0 {
trimmed := strings.TrimRight(line, "\r\n")
if strings.HasPrefix(trimmed, "data:") {
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if payload == "" || payload == "[DONE]" {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
} else {
// 解包 v1internal 响应
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
if parseErr == nil && inner != nil {
payload = string(inner)
}
// 解析 usage
var parsed map[string]any
if json.Unmarshal(inner, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
flusher.Flush()
}
} else {
_, _ = io.WriteString(c.Writer, line)
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解包 v1internal 响应
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
var parsed map[string]any
if json.Unmarshal(unwrapped, &parsed) == nil {
if u := extractGeminiUsage(parsed); u != nil {
c.Data(resp.StatusCode, "application/json", unwrapped)
return u, nil
}
}
c.Data(resp.StatusCode, "application/json", unwrapped)
return &ClaudeUsage{}, nil
}
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": message},
})
return fmt.Errorf("%s", message)
}
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
// 记录上游错误详情便于调试
log.Printf("Antigravity upstream error %d: %s", upstreamStatus, string(body))
var statusCode int
var errType, errMsg string
switch upstreamStatus {
case 400:
statusCode = http.StatusBadRequest
errType = "invalid_request_error"
errMsg = "Invalid request"
case 401:
statusCode = http.StatusBadGateway
errType = "authentication_error"
errMsg = "Upstream authentication failed"
case 403:
statusCode = http.StatusBadGateway
errType = "permission_error"
errMsg = "Upstream access forbidden"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
return fmt.Errorf("upstream error: %d", upstreamStatus)
}
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
statusStr := "UNKNOWN"
switch status {
case 400:
statusStr = "INVALID_ARGUMENT"
case 404:
statusStr = "NOT_FOUND"
case 429:
statusStr = "RESOURCE_EXHAUSTED"
case 500:
statusStr = "INTERNAL"
case 502, 503:
statusStr = "UNAVAILABLE"
}
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": statusStr,
},
})
return fmt.Errorf("%s", message)
}
// handleClaudeNonStreamingResponse 处理 Claude 非流式响应Gemini → Claude 转换)
func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil {
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
}
// 转换 Gemini 响应为 Claude 格式
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel)
if err != nil {
log.Printf("Transform Gemini to Claude failed: %v, body: %s", err, string(body))
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
}
c.Data(http.StatusOK, "application/json", claudeResp)
// 转换为 service.ClaudeUsage
usage := &ClaudeUsage{
InputTokens: agUsage.InputTokens,
OutputTokens: agUsage.OutputTokens,
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
return usage, nil
}
// handleClaudeStreamingResponse 处理 Claude 流式响应Gemini SSE → Claude SSE 转换)
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
processor := antigravity.NewStreamingProcessor(originalModel)
var firstTokenMs *int
reader := bufio.NewReader(resp.Body)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
if agUsage == nil {
return &ClaudeUsage{}
}
return &ClaudeUsage{
InputTokens: agUsage.InputTokens,
OutputTokens: agUsage.OutputTokens,
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
CacheReadInputTokens: agUsage.CacheReadInputTokens,
}
}
for {
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("stream read error: %w", err)
}
if len(line) > 0 {
// 处理 SSE 行,转换为 Claude 格式
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
if len(claudeEvents) > 0 {
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
}
flusher.Flush()
}
}
if errors.Is(err, io.EOF) {
break
}
}
// 发送结束事件
finalEvents, agUsage := processor.Finish()
if len(finalEvents) > 0 {
_, _ = c.Writer.Write(finalEvents)
flusher.Flush()
}
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
}

View File

@@ -0,0 +1,269 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAntigravityModelSupported(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// 直接支持的模型
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
// 可映射的模型
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
{"可映射 - claude-opus-4", "claude-opus-4", true},
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
// Claude 前缀兜底
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
{"Claude前缀 - claude-future-version", "claude-future-version", true},
// 不支持的模型
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - gpt-4o", "gpt-4o", false},
{"不支持 - llama-3", "llama-3", false},
{"不支持 - mistral-7b", "mistral-7b", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsAntigravityModelSupported(tt.model)
require.Equal(t, tt.expected, got, "model: %s", tt.model)
})
}
}
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先注意model_mapping 在 credentials 中存储为 map[string]any
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
expected: "custom-model",
},
{
name: "账户映射覆盖系统映射",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 系统默认映射
{
name: "系统映射 - claude-3-5-sonnet-20241022",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-5-sonnet-20240620",
requestedModel: "claude-3-5-sonnet-20240620",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-opus-4",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-opus-4-5-20251101",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-haiku-4 → gemini-3-flash",
requestedModel: "claude-haiku-4",
accountMapping: nil,
expected: "gemini-3-flash",
},
{
name: "系统映射 - claude-haiku-4-5 → gemini-3-flash",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "gemini-3-flash",
},
{
name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash",
requestedModel: "claude-3-haiku-20240307",
accountMapping: nil,
expected: "gemini-3-flash",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "gemini-3-flash",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 3. Gemini 透传
{
name: "Gemini透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
},
{
name: "Gemini透传 - gemini-1.5-pro",
requestedModel: "gemini-1.5-pro",
accountMapping: nil,
expected: "gemini-1.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "gemini-future-model",
},
// 4. 直接支持的模型
{
name: "直接支持 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "直接支持 - claude-opus-4-5-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-5-thinking",
},
{
name: "直接支持 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
// 5. 默认值 fallback未知 claude 模型)
{
name: "默认值 - claude-unknown",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认值 - claude-3-opus-20240229",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
}
if tt.accountMapping != nil {
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny := make(map[string]any)
for k, v := range tt.accountMapping {
mappingAny[k] = v
}
account.Credentials = map[string]any{
"model_mapping": mappingAny,
}
}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
})
}
}
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
expected string
}{
// 空字符串回退到默认值
{"空字符串", "", "claude-sonnet-4-5"},
// 非 claude/gemini 前缀回退到默认值
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: PlatformAntigravity}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got)
})
}
}
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
model string
expected bool
}{
// 直接支持
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射
{"可映射 - claude-opus-4", "claude-opus-4", true},
// 前缀透传
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
// 不支持
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.IsModelSupported(tt.model)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -0,0 +1,267 @@
package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
type AntigravityOAuthService struct {
sessionStore *antigravity.SessionStore
proxyRepo ProxyRepository
}
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
return &AntigravityOAuthService{
sessionStore: antigravity.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
// AntigravityAuthURLResult is the result of generating an authorization URL
type AntigravityAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
// GenerateAuthURL 生成 Google OAuth 授权链接
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
state, err := antigravity.GenerateState()
if err != nil {
return nil, fmt.Errorf("生成 state 失败: %w", err)
}
codeVerifier, err := antigravity.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
}
sessionID, err := antigravity.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("生成 session_id 失败: %w", err)
}
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
session := &antigravity.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
return &AntigravityAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
State: state,
}, nil
}
// AntigravityExchangeCodeInput 交换 code 的输入
type AntigravityExchangeCodeInput struct {
SessionID string
State string
Code string
ProxyID *int64
}
// AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"`
}
// ExchangeCode 用 authorization code 交换 token
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session 不存在或已过期")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("state 无效")
}
// 确定代理 URL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
client := antigravity.NewClient(proxyURL)
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
if err != nil {
return nil, fmt.Errorf("token 交换失败: %w", err)
}
// 删除 session
s.sessionStore.Delete(input.SessionID)
// 计算过期时间(减去 5 分钟安全窗口)
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
result := &AntigravityTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
}
// 获取用户信息
userInfo, err := client.GetUserInfo(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
} else {
result.Email = userInfo.Email
}
// 获取 project_id
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
result.ProjectID = loadResp.CloudAICompanionProject
}
return result, nil
}
// RefreshToken 刷新 token
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
var lastErr error
for attempt := 0; attempt <= 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
time.Sleep(backoff)
}
client := antigravity.NewClient(proxyURL)
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
return &AntigravityTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
}, nil
}
if isNonRetryableAntigravityOAuthError(err) {
return nil, err
}
lastErr = err
}
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
}
func isNonRetryableAntigravityOAuthError(err error) bool {
msg := err.Error()
nonRetryable := []string{
"invalid_grant",
"invalid_client",
"unauthorized_client",
"access_denied",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
return true
}
}
return false
}
// RefreshAccountToken 刷新账户的 token
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("非 Antigravity OAuth 账户")
}
refreshToken := account.GetCredential("refresh_token")
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("无可用的 refresh_token")
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// 保留原有的 project_id 和 email
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
if existingProjectID != "" {
tokenInfo.ProjectID = existingProjectID
}
existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" {
tokenInfo.Email = existingEmail
}
return tokenInfo, nil
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.TokenType != "" {
creds["token_type"] = tokenInfo.TokenType
}
if tokenInfo.Email != "" {
creds["email"] = tokenInfo.Email
}
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
return creds
}
// Stop 停止服务
func (s *AntigravityOAuthService) Stop() {
s.sessionStore.Stop()
}

View File

@@ -0,0 +1,225 @@
package service
import (
"context"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
type AntigravityQuotaRefresher struct {
accountRepo AccountRepository
proxyRepo ProxyRepository
cfg *config.TokenRefreshConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// NewAntigravityQuotaRefresher 创建配额刷新器
func NewAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
_ *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
return &AntigravityQuotaRefresher{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
}
}
// Start 启动后台配额刷新服务
func (r *AntigravityQuotaRefresher) Start() {
if !r.cfg.Enabled {
log.Println("[AntigravityQuota] Service disabled by configuration")
return
}
r.wg.Add(1)
go r.refreshLoop()
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
}
// Stop 停止服务
func (r *AntigravityQuotaRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
log.Println("[AntigravityQuota] Service stopped")
}
// refreshLoop 刷新循环
func (r *AntigravityQuotaRefresher) refreshLoop() {
defer r.wg.Done()
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次
r.processRefresh()
for {
select {
case <-ticker.C:
r.processRefresh()
case <-r.stopCh:
return
}
}
}
// processRefresh 执行一次刷新
func (r *AntigravityQuotaRefresher) processRefresh() {
ctx := context.Background()
// 查询所有 active 的账户,然后过滤 antigravity 平台
allAccounts, err := r.accountRepo.ListActive(ctx)
if err != nil {
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
return
}
// 过滤 antigravity 平台账户
var accounts []Account
for _, acc := range allAccounts {
if acc.Platform == PlatformAntigravity {
accounts = append(accounts, acc)
}
}
if len(accounts) == 0 {
return
}
refreshed, failed := 0, 0
for i := range accounts {
account := &accounts[i]
if err := r.refreshAccountQuota(ctx, account); err != nil {
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
failed++
} else {
refreshed++
}
}
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
len(accounts), refreshed, failed)
}
// refreshAccountQuota 刷新单个账户的配额
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
if accessToken == "" || projectID == "" {
return nil // 没有有效凭证,跳过
}
// token 过期则跳过,由 TokenRefreshService 负责刷新
if r.isTokenExpired(account) {
return nil
}
// 获取代理 URL
var proxyURL string
if account.ProxyID != nil {
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
client := antigravity.NewClient(proxyURL)
// 获取账户类型tier
loadResp, _ := client.LoadCodeAssist(ctx, accessToken)
if loadResp != nil {
r.updateAccountTier(account, loadResp)
}
// 调用 API 获取配额
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
return err
}
// 解析配额数据并更新 extra 字段
r.updateAccountQuota(account, modelsResp)
// 保存到数据库
return r.accountRepo.Update(ctx, account)
}
// isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := parseAntigravityExpiresAt(account)
if expiresAt == nil {
return false
}
// 提前 5 分钟认为过期
return time.Now().Add(5 * time.Minute).After(*expiresAt)
}
// updateAccountTier 更新账户类型信息
func (r *AntigravityQuotaRefresher) updateAccountTier(account *Account, loadResp *antigravity.LoadCodeAssistResponse) {
if account.Extra == nil {
account.Extra = make(map[string]any)
}
tier := loadResp.GetTier()
if tier != "" {
account.Extra["tier"] = tier
}
// 保存不符合条件的原因(如 INELIGIBLE_ACCOUNT
if len(loadResp.IneligibleTiers) > 0 && loadResp.IneligibleTiers[0] != nil {
ineligible := loadResp.IneligibleTiers[0]
if ineligible.ReasonCode != "" {
account.Extra["ineligible_reason_code"] = ineligible.ReasonCode
}
if ineligible.ReasonMessage != "" {
account.Extra["ineligible_reason_message"] = ineligible.ReasonMessage
}
}
}
// updateAccountQuota 更新账户的配额信息
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
if account.Extra == nil {
account.Extra = make(map[string]any)
}
quota := make(map[string]any)
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
quota[modelName] = map[string]any{
"remaining": remaining,
"reset_time": modelInfo.QuotaInfo.ResetTime,
}
}
account.Extra["quota"] = quota
account.Extra["last_quota_check"] = time.Now().Format(time.RFC3339)
}

View File

@@ -0,0 +1,145 @@
package service
import (
"context"
"errors"
"log"
"strconv"
"strings"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
)
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenProvider(
accountRepo AccountRepository,
tokenCache AntigravityTokenCache,
antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
return &AntigravityTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
cacheKey := antigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account)
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseAntigravityExpiresAt(account)
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
}
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", err
}
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = parseAntigravityExpiresAt(account)
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func antigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}

View File

@@ -0,0 +1,57 @@
package service
import (
"context"
"strconv"
"time"
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type AntigravityTokenRefresher struct {
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
return &AntigravityTokenRefresher{
antigravityOAuthService: antigravityOAuthService,
}
}
// CanRefresh 检查是否可以刷新此账户
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
}
// Refresh 执行 token 刷新
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
return newCredentials, nil
}

View File

@@ -18,9 +18,10 @@ const (
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// Account type constants

View File

@@ -0,0 +1,777 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// testConfig 返回一个用于测试的默认配置
func testConfig() *config.Config {
return &config.Config{RunMode: config.RunModeStandard}
}
// mockAccountRepoForPlatform 单平台测试用的 mock
type mockAccountRepoForPlatform struct {
accounts []Account
accountsByID map[int64]*Account
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
}
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
if m.listPlatformFunc != nil {
return m.listPlatformFunc(ctx, platform)
}
var result []Account
for _, acc := range m.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
return m.ListSchedulableByPlatform(ctx, platform)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
return nil
}
func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListActive(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) UpdateLastUsed(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
type mockGatewayCacheForPlatform struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
func ptr[T any](v T) *T {
return &v
}
// TestGatewayService_SelectAccountForModelWithPlatform_Anthropic 测试 anthropic 单平台选择
func TestGatewayService_SelectAccountForModelWithPlatform_Anthropic(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的 anthropic 账户")
require.Equal(t, PlatformAnthropic, acc.Platform, "应只返回 anthropic 平台账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_Antigravity 测试 antigravity 单平台选择
func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform, "应只返回 antigravity 平台账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed 测试优先级和最后使用时间
func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t *testing.T) {
ctx := context.Background()
now := time.Now()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
}
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
func TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
}
// TestGatewayService_SelectAccountForModelWithPlatform_Schedulability 测试账户可调度性检查
func TestGatewayService_SelectAccountForModelWithPlatform_Schedulability(t *testing.T) {
ctx := context.Background()
now := time.Now()
tests := []struct {
name string
accounts []Account
expectedID int64
}{
{
name: "过载账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "限流账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "非active账户被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "schedulable=false被跳过",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 2,
},
{
name: "过期的过载账户可调度",
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
expectedID: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: tt.accounts,
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, tt.expectedID, acc.ID)
})
}
}
// TestGatewayService_SelectAccountForModelWithPlatform_StickySession 测试粘性会话
func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中-同平台", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话不匹配平台-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 粘性会话绑定但平台不匹配
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1}, // 绑定 antigravity 账户
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
// 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户平台不匹配,应降级选择同平台账户")
require.Equal(t, PlatformAnthropic, acc.Platform)
})
t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
excludedIDs := map[int64]struct{}{1: {}}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户")
})
t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户")
})
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Anthropic平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformAnthropic},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Anthropic平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Anthropic平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户包含启用混合调度的antigravity")
})
t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "未启用mixed_scheduling的antigravity账户应被过滤")
require.Equal(t, PlatformAnthropic, acc.Platform)
})
t.Run("混合调度-粘性会话命中启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
})
t.Run("混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling应降级选择anthropic账户")
})
t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
})
t.Run("混合调度-无可用账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
})
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
tests := []struct {
name string
account Account
expected bool
}{
{
name: "非antigravity平台-返回false",
account: Account{Platform: PlatformAnthropic},
expected: false,
},
{
name: "antigravity平台-无extra-返回false",
account: Account{Platform: PlatformAntigravity},
expected: false,
},
{
name: "antigravity平台-extra无mixed_scheduling-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{}},
expected: false,
},
{
name: "antigravity平台-mixed_scheduling=false-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": false}},
expected: false,
},
{
name: "antigravity平台-mixed_scheduling=true-返回true",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": true}},
expected: true,
},
{
name: "antigravity平台-mixed_scheduling非bool类型-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": "true"}},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsMixedSchedulingEnabled()
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -93,6 +94,7 @@ func (e *UpstreamFailoverError) Error() string {
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
@@ -109,6 +111,7 @@ type GatewayService struct {
// NewGatewayService creates a new GatewayService
func NewGatewayService(
accountRepo AccountRepository,
groupRepo GroupRepository,
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
@@ -123,6 +126,7 @@ func NewGatewayService(
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
@@ -291,16 +295,53 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// 续期粘性会话
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
@@ -310,16 +351,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
}
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
// 2. 获取可调度账号列表(平台)
var accounts []Account
var err error
if s.cfg.RunMode == config.RunModeSimple {
// 简易模式:忽略 groupID查询所有可用账号
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -332,19 +373,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 检查模型支持
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
// 优先选择priority值更小的priority值越小优先级越高
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
// 优先级相同时,选最久未用的
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
@@ -377,6 +415,126 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return selected, nil
}
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
}
// 2. 获取可调度账号列表
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -1116,6 +1274,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
var req struct {

View File

@@ -18,6 +18,7 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -33,26 +34,32 @@ const (
)
type GeminiMessagesCompatService struct {
accountRepo AccountRepository
cache GatewayCache
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
accountRepo AccountRepository
groupRepo GroupRepository
cache GatewayCache
tokenProvider *GeminiTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
}
func NewGeminiMessagesCompatService(
accountRepo AccountRepository,
groupRepo GroupRepository,
cache GatewayCache,
tokenProvider *GeminiTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
antigravityGatewayService *AntigravityGatewayService,
) *GeminiMessagesCompatService {
return &GeminiMessagesCompatService{
accountRepo: accountRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
accountRepo: accountRepo,
groupRepo: groupRepo,
cache: cache,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
}
}
@@ -66,26 +73,71 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// 无分组时只使用原生 gemini 平台
platform = PlatformGemini
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
var queryPlatforms []string
if useMixedScheduling {
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
} else {
queryPlatforms = []string{platform}
}
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
}
if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -97,7 +149,12 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
// 混合调度模式下原生平台直接通过antigravity 需要启用 mixed_scheduling
// 非混合调度模式antigravity 分组):不需要过滤
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
@@ -139,6 +196,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
}
return account.IsModelSupported(requestedModel)
}
// GetAntigravityGatewayService 返回 AntigravityGatewayService
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
return s.antigravityGatewayService
}
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
}
if err != nil {
return false, err
}
return len(accounts) > 0, nil
}
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
//
@@ -1798,7 +1883,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 {
return
}
resetAt := parseGeminiRateLimitResetTime(body)
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
ra := time.Now().Add(5 * time.Minute)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
@@ -1807,7 +1892,8 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
}
func parseGeminiRateLimitResetTime(body []byte) *int64 {
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
// Try to parse metadata.quotaResetDelay like "12.345s"
var parsed map[string]any
if err := json.Unmarshal(body, &parsed); err == nil {

View File

@@ -0,0 +1,493 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// mockAccountRepoForGemini Gemini 测试用的 mock
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
if acc, ok := m.accountsByID[id]; ok {
return acc, nil
}
return nil, errors.New("account not found")
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range m.accounts {
if acc.Platform == platform && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
// 测试时不区分 groupID直接按 platform 过滤
return m.ListSchedulableByPlatform(ctx, platform)
}
// Stub methods to implement AccountRepository interface
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
platformSet[p] = true
}
for _, acc := range m.accounts {
if platformSet[acc.Platform] && acc.IsSchedulable() {
result = append(result, acc)
}
}
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return nil
}
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
return 0, nil
}
// Verify interface implementation
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type mockGroupRepoForGemini struct {
groups map[int64]*Group
}
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, errors.New("group not found")
}
// Stub methods to implement GroupRepository interface
func (m *mockGroupRepoForGemini) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGemini) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
m.sessionBindings[sessionHash] = accountID
return nil
}
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
// 无分组时使用 gemini 平台
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的 gemini 账户")
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被选择
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{
groups: map[int64]*Group{
1: {ID: 1, Platform: PlatformAntigravity},
},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
groupID := int64(1)
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform, "antigravity 分组应只返回 antigravity 账户")
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred 测试 OAuth 优先
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
require.Equal(t, AccountTypeOAuth, acc.Type)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts 测试无可用账户
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available")
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession 测试粘性会话
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
ctx := context.Background()
t.Run("粘性会话命中-同平台", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 注意:缓存键使用 "gemini:" 前缀
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
})
t.Run("粘性会话平台不匹配-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 粘性会话绑定
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1}, // 绑定 antigravity 账户
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
// 无分组时使用 gemini 平台,粘性会话绑定的 antigravity 账户平台不匹配
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "粘性会话账户平台不匹配,应降级选择 gemini 账户")
require.Equal(t, PlatformGemini, acc.Platform)
})
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
// 缓存键没有 "gemini:" 前缀,不应命中
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
})
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
tests := []struct {
name string
platform string
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
}{
{
name: "Gemini平台走ForwardNative",
platform: PlatformGemini,
expectedService: "gemini",
},
{
name: "Antigravity平台走ForwardGemini",
platform: PlatformAntigravity,
expectedService: "antigravity",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: tt.platform}
// 模拟 Handler 层的路由逻辑
var serviceName string
if account.Platform == PlatformAntigravity {
serviceName = "antigravity"
} else {
serviceName = "gemini"
}
require.Equal(t, tt.expectedService, serviceName,
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
})
}
}
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
svc := &GeminiMessagesCompatService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "Gemini平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformGemini},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -27,6 +27,7 @@ func NewTokenRefreshService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService),
NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
}
return s

View File

@@ -17,7 +17,7 @@ type BuildInfo struct {
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
svc := NewPricingService(cfg, remoteClient)
if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
// Pricing service initialization failure should not block startup, use fallback prices
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
return svc, nil
@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
cfg *config.Config,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg)
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
svc.Start()
return svc
}
@@ -53,6 +54,18 @@ func ProvideTimingWheelService() *TimingWheelService {
return svc
}
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
func ProvideAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
oauthSvc *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
svc.Start()
return svc
}
// ProvideDeferredService creates and starts DeferredService
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
@@ -81,8 +94,11 @@ var ProviderSet = wire.NewSet(
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
NewAntigravityGatewayService,
NewRateLimitService,
NewAccountUsageService,
NewAccountTestService,
@@ -98,4 +114,5 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService,
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
)