Merge branch 'main' into test-dev
This commit is contained in:
@@ -346,3 +346,20 @@ func (a *Account) IsOpenAITokenExpired() bool {
|
||||
}
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
// IsMixedSchedulingEnabled 检查 antigravity 账户是否启用混合调度
|
||||
// 启用后可参与 anthropic/gemini 分组的账户调度
|
||||
func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["mixed_scheduling"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ type AccountRepository interface {
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
|
||||
823
backend/internal/service/antigravity_gateway_service.go
Normal file
823
backend/internal/service/antigravity_gateway_service.go
Normal file
@@ -0,0 +1,823 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityMaxRetries = 5
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// Antigravity 直接支持的模型
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
"claude-sonnet-4-5": true,
|
||||
"claude-sonnet-4-5-thinking": true,
|
||||
"gemini-2.5-flash": true,
|
||||
"gemini-2.5-flash-lite": true,
|
||||
"gemini-2.5-flash-thinking": true,
|
||||
"gemini-3-flash": true,
|
||||
"gemini-3-pro-low": true,
|
||||
"gemini-3-pro-high": true,
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image": true,
|
||||
}
|
||||
|
||||
// Antigravity 系统默认模型映射表(不支持 → 支持)
|
||||
var antigravityModelMapping = map[string]string{
|
||||
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
|
||||
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
|
||||
"claude-opus-4": "claude-opus-4-5-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
|
||||
"claude-haiku-4": "gemini-3-flash",
|
||||
"claude-haiku-4-5": "gemini-3-flash",
|
||||
"claude-3-haiku-20240307": "gemini-3-flash",
|
||||
"claude-haiku-4-5-20251001": "gemini-3-flash",
|
||||
// 生图模型:官方名 → Antigravity 内部名
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
}
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
type AntigravityGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
_ GatewayCache,
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTokenProvider 返回 token provider
|
||||
func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider {
|
||||
return s.tokenProvider
|
||||
}
|
||||
|
||||
// getMappedModel 获取映射后的模型名
|
||||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||||
// 1. 优先使用账户级映射(复用现有方法)
|
||||
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 2. 系统默认映射
|
||||
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 3. Gemini 模型透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 4. Claude 前缀透传直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 5. 默认值
|
||||
return "claude-sonnet-4-5"
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否被支持
|
||||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wrapV1InternalRequest 包装请求为 v1internal 格式
|
||||
func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model string, originalBody []byte) ([]byte, error) {
|
||||
var request any
|
||||
if err := json.Unmarshal(originalBody, &request); err != nil {
|
||||
return nil, fmt.Errorf("解析请求体失败: %w", err)
|
||||
}
|
||||
|
||||
wrapped := map[string]any{
|
||||
"project": projectID,
|
||||
"requestId": "agent-" + uuid.New().String(),
|
||||
"userAgent": "sub2api",
|
||||
"requestType": "agent",
|
||||
"model": model,
|
||||
"request": request,
|
||||
}
|
||||
|
||||
return json.Marshal(wrapped)
|
||||
}
|
||||
|
||||
// unwrapV1InternalResponse 解包 v1internal 响应
|
||||
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
||||
var outer map[string]any
|
||||
if err := json.Unmarshal(body, &outer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp, ok := outer["response"]; ok {
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析 Claude 请求
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
|
||||
originalModel := claudeReq.Model
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
if mappedModel != claudeReq.Model {
|
||||
log.Printf("Antigravity model mapping: %s -> %s (account: %s)", claudeReq.Model, mappedModel, account.Name)
|
||||
}
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 project_id
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
return nil, errors.New("project_id not found in credentials")
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 转换 Claude 请求为 Gemini 格式
|
||||
geminiBody, err := antigravity.TransformClaudeToGemini(&claudeReq, projectID, mappedModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// 构建上游 URL
|
||||
action := "generateContent"
|
||||
if claudeReq.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, action)
|
||||
if claudeReq.Stream {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
if claudeReq.Stream {
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleClaudeNonStreamingResponse(c, resp, originalModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if strings.TrimSpace(originalModel) == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
}
|
||||
if strings.TrimSpace(action) == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing action in URL")
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent", "countTokens":
|
||||
// ok
|
||||
default:
|
||||
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
|
||||
}
|
||||
|
||||
mappedModel := s.getMappedModel(account, originalModel)
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("antigravity token provider not configured")
|
||||
}
|
||||
accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 access_token 失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取 project_id
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
return nil, errors.New("project_id not found in credentials")
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建上游 URL
|
||||
upstreamAction := action
|
||||
if action == "generateContent" && stream {
|
||||
upstreamAction = "streamGenerateContent"
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:%s", antigravity.BaseURL, upstreamAction)
|
||||
if stream || upstreamAction == "streamGenerateContent" {
|
||||
fullURL += "?alt=sse"
|
||||
}
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
upstreamReq.Header.Set("Content-Type", "application/json")
|
||||
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
upstreamReq.Header.Set("User-Agent", antigravity.UserAgent)
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if action == "countTokens" {
|
||||
estimated := estimateGeminiCountTokens(body)
|
||||
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 解包并返回错误
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, unwrapped)
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
if stream || upstreamAction == "streamGenerateContent" {
|
||||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
firstTokenMs = streamRes.firstTokenMs
|
||||
} else {
|
||||
usageResp, err := s.handleGeminiNonStreamingResponse(c, resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = usageResp
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
if resetAt == nil {
|
||||
// 解析失败:Gemini 有重试时间用 5 分钟,Claude 没有用 1 分钟
|
||||
defaultDur := 1 * time.Minute
|
||||
if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
|
||||
defaultDur = 5 * time.Minute
|
||||
}
|
||||
ra := time.Now().Add(defaultDur)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
return
|
||||
}
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||
return
|
||||
}
|
||||
// 其他错误码继续使用 rateLimitService
|
||||
if s.rateLimitService == nil {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
}
|
||||
|
||||
type antigravityStreamResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) {
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "text/event-stream; charset=utf-8"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if strings.HasPrefix(trimmed, "data:") {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if payload == "" || payload == "[DONE]" {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
} else {
|
||||
// 解包 v1internal 响应
|
||||
inner, parseErr := s.unwrapV1InternalResponse([]byte(payload))
|
||||
if parseErr == nil && inner != nil {
|
||||
payload = string(inner)
|
||||
}
|
||||
|
||||
// 解析 usage
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(inner, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
_, _ = io.WriteString(c.Writer, line)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleGeminiNonStreamingResponse(c *gin.Context, resp *http.Response) (*ClaudeUsage, error) {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解包 v1internal 响应
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(unwrapped, &parsed) == nil {
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
c.Data(resp.StatusCode, "application/json", unwrapped)
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", unwrapped)
|
||||
return &ClaudeUsage{}, nil
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": message},
|
||||
})
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||||
// 记录上游错误详情便于调试
|
||||
log.Printf("Antigravity upstream error %d: %s", upstreamStatus, string(body))
|
||||
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
|
||||
switch upstreamStatus {
|
||||
case 400:
|
||||
statusCode = http.StatusBadRequest
|
||||
errType = "invalid_request_error"
|
||||
errMsg = "Invalid request"
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "authentication_error"
|
||||
errMsg = "Upstream authentication failed"
|
||||
case 403:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "permission_error"
|
||||
errMsg = "Upstream access forbidden"
|
||||
case 429:
|
||||
statusCode = http.StatusTooManyRequests
|
||||
errType = "rate_limit_error"
|
||||
errMsg = "Upstream rate limit exceeded"
|
||||
case 529:
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
errType = "overloaded_error"
|
||||
errMsg = "Upstream service overloaded"
|
||||
default:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream request failed"
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": errType, "message": errMsg},
|
||||
})
|
||||
return fmt.Errorf("upstream error: %d", upstreamStatus)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
statusStr := "UNKNOWN"
|
||||
switch status {
|
||||
case 400:
|
||||
statusStr = "INVALID_ARGUMENT"
|
||||
case 404:
|
||||
statusStr = "NOT_FOUND"
|
||||
case 429:
|
||||
statusStr = "RESOURCE_EXHAUSTED"
|
||||
case 500:
|
||||
statusStr = "INTERNAL"
|
||||
case 502, 503:
|
||||
statusStr = "UNAVAILABLE"
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": statusStr,
|
||||
},
|
||||
})
|
||||
return fmt.Errorf("%s", message)
|
||||
}
|
||||
|
||||
// handleClaudeNonStreamingResponse 处理 Claude 非流式响应(Gemini → Claude 转换)
|
||||
func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
|
||||
}
|
||||
|
||||
// 转换 Gemini 响应为 Claude 格式
|
||||
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("Transform Gemini to Claude failed: %v, body: %s", err, string(body))
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", claudeResp)
|
||||
|
||||
// 转换为 service.ClaudeUsage
|
||||
usage := &ClaudeUsage{
|
||||
InputTokens: agUsage.InputTokens,
|
||||
OutputTokens: agUsage.OutputTokens,
|
||||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// handleClaudeStreamingResponse 处理 Claude 流式响应(Gemini SSE → Claude SSE 转换)
|
||||
func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*antigravityStreamResult, error) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
processor := antigravity.NewStreamingProcessor(originalModel)
|
||||
var firstTokenMs *int
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||||
if agUsage == nil {
|
||||
return &ClaudeUsage{}
|
||||
}
|
||||
return &ClaudeUsage{
|
||||
InputTokens: agUsage.InputTokens,
|
||||
OutputTokens: agUsage.OutputTokens,
|
||||
CacheCreationInputTokens: agUsage.CacheCreationInputTokens,
|
||||
CacheReadInputTokens: agUsage.CacheReadInputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
if len(line) > 0 {
|
||||
// 处理 SSE 行,转换为 Claude 格式
|
||||
claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n"))
|
||||
|
||||
if len(claudeEvents) > 0 {
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil {
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
}
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 发送结束事件
|
||||
finalEvents, agUsage := processor.Finish()
|
||||
if len(finalEvents) > 0 {
|
||||
_, _ = c.Writer.Write(finalEvents)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
269
backend/internal/service/antigravity_model_mapping_test.go
Normal file
269
backend/internal/service/antigravity_model_mapping_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAntigravityModelSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持的模型
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
|
||||
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
|
||||
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
|
||||
|
||||
// 可映射的模型
|
||||
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
|
||||
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
|
||||
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
|
||||
|
||||
// Gemini 前缀透传
|
||||
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
|
||||
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
|
||||
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
|
||||
|
||||
// Claude 前缀兜底
|
||||
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
|
||||
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
|
||||
{"Claude前缀 - claude-future-version", "claude-future-version", true},
|
||||
|
||||
// 不支持的模型
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - gpt-4o", "gpt-4o", false},
|
||||
{"不支持 - llama-3", "llama-3", false},
|
||||
{"不支持 - mistral-7b", "mistral-7b", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsAntigravityModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
accountMapping map[string]string
|
||||
expected string
|
||||
}{
|
||||
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
|
||||
{
|
||||
name: "账户映射优先",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
|
||||
expected: "custom-model",
|
||||
},
|
||||
{
|
||||
name: "账户映射覆盖系统映射",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
|
||||
expected: "my-opus",
|
||||
},
|
||||
|
||||
// 2. 系统默认映射
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20241022",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20240620",
|
||||
requestedModel: "claude-3-5-sonnet-20240620",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4-5-20251101",
|
||||
requestedModel: "claude-opus-4-5-20251101",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4 → gemini-3-flash",
|
||||
requestedModel: "claude-haiku-4",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5 → gemini-3-flash",
|
||||
requestedModel: "claude-haiku-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash",
|
||||
requestedModel: "claude-3-haiku-20240307",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash",
|
||||
requestedModel: "claude-haiku-4-5-20251001",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||
requestedModel: "claude-sonnet-4-5-20250929",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-1.5-pro",
|
||||
requestedModel: "gemini-1.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-1.5-pro",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
requestedModel: "gemini-future-model",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-future-model",
|
||||
},
|
||||
|
||||
// 4. 直接支持的模型
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-opus-4-5-thinking",
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5-thinking",
|
||||
requestedModel: "claude-sonnet-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// 5. 默认值 fallback(未知 claude 模型)
|
||||
{
|
||||
name: "默认值 - claude-unknown",
|
||||
requestedModel: "claude-unknown",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认值 - claude-3-opus-20240229",
|
||||
requestedModel: "claude-3-opus-20240229",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
if tt.accountMapping != nil {
|
||||
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
|
||||
mappingAny := make(map[string]any)
|
||||
for k, v := range tt.accountMapping {
|
||||
mappingAny[k] = v
|
||||
}
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": mappingAny,
|
||||
}
|
||||
}
|
||||
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
// 空字符串回退到默认值
|
||||
{"空字符串", "", "claude-sonnet-4-5"},
|
||||
|
||||
// 非 claude/gemini 前缀回退到默认值
|
||||
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
|
||||
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
||||
|
||||
// 可映射
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
|
||||
// 前缀透传
|
||||
{"Gemini前缀", "gemini-unknown", true},
|
||||
{"Claude前缀", "claude-unknown", true},
|
||||
|
||||
// 不支持
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.IsModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
267
backend/internal/service/antigravity_oauth_service.go
Normal file
267
backend/internal/service/antigravity_oauth_service.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
type AntigravityOAuthService struct {
|
||||
sessionStore *antigravity.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
|
||||
return &AntigravityOAuthService{
|
||||
sessionStore: antigravity.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// AntigravityAuthURLResult is the result of generating an authorization URL
|
||||
type AntigravityAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成 Google OAuth 授权链接
|
||||
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
|
||||
state, err := antigravity.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 state 失败: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := antigravity.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
|
||||
}
|
||||
|
||||
sessionID, err := antigravity.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 session_id 失败: %w", err)
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
session := &antigravity.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
|
||||
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
return &AntigravityAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AntigravityExchangeCodeInput 交换 code 的输入
|
||||
type AntigravityExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session 不存在或已过期")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
return nil, fmt.Errorf("state 无效")
|
||||
}
|
||||
|
||||
// 确定代理 URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除 session
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间(减去 5 分钟安全窗口)
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
|
||||
result := &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
} else {
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id
|
||||
loadResp, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 token
|
||||
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err == nil {
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
return &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if isNonRetryableAntigravityOAuthError(err) {
|
||||
return nil, err
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
msg := err.Error()
|
||||
nonRetryable := []string{
|
||||
"invalid_grant",
|
||||
"invalid_client",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RefreshAccountToken 刷新账户的 token
|
||||
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("非 Antigravity OAuth 账户")
|
||||
}
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("无可用的 refresh_token")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 project_id 和 email
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.TokenType != "" {
|
||||
creds["token_type"] = tokenInfo.TokenType
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *AntigravityOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
225
backend/internal/service/antigravity_quota_refresher.go
Normal file
225
backend/internal/service/antigravity_quota_refresher.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaRefresher 定时刷新 Antigravity 账户的配额信息
|
||||
type AntigravityQuotaRefresher struct {
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaRefresher 创建配额刷新器
|
||||
func NewAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
_ *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
return &AntigravityQuotaRefresher{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动后台配额刷新服务
|
||||
func (r *AntigravityQuotaRefresher) Start() {
|
||||
if !r.cfg.Enabled {
|
||||
log.Println("[AntigravityQuota] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.refreshLoop()
|
||||
|
||||
log.Printf("[AntigravityQuota] Service started (check every %d minutes)", r.cfg.CheckIntervalMinutes)
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (r *AntigravityQuotaRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
log.Println("[AntigravityQuota] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (r *AntigravityQuotaRefresher) refreshLoop() {
|
||||
defer r.wg.Done()
|
||||
|
||||
checkInterval := time.Duration(r.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次
|
||||
r.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
r.processRefresh()
|
||||
case <-r.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新
|
||||
func (r *AntigravityQuotaRefresher) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 查询所有 active 的账户,然后过滤 antigravity 平台
|
||||
allAccounts, err := r.accountRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[AntigravityQuota] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤 antigravity 平台账户
|
||||
var accounts []Account
|
||||
for _, acc := range allAccounts {
|
||||
if acc.Platform == PlatformAntigravity {
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
if err := r.refreshAccountQuota(ctx, account); err != nil {
|
||||
log.Printf("[AntigravityQuota] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
refreshed++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[AntigravityQuota] Cycle complete: total=%d, refreshed=%d, failed=%d",
|
||||
len(accounts), refreshed, failed)
|
||||
}
|
||||
|
||||
// refreshAccountQuota 刷新单个账户的配额
|
||||
func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, account *Account) error {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
if accessToken == "" || projectID == "" {
|
||||
return nil // 没有有效凭证,跳过
|
||||
}
|
||||
|
||||
// token 过期则跳过,由 TokenRefreshService 负责刷新
|
||||
if r.isTokenExpired(account) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取代理 URL
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := r.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 获取账户类型(tier)
|
||||
loadResp, _ := client.LoadCodeAssist(ctx, accessToken)
|
||||
if loadResp != nil {
|
||||
r.updateAccountTier(account, loadResp)
|
||||
}
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析配额数据并更新 extra 字段
|
||||
r.updateAccountQuota(account, modelsResp)
|
||||
|
||||
// 保存到数据库
|
||||
return r.accountRepo.Update(ctx, account)
|
||||
}
|
||||
|
||||
// isTokenExpired 检查 token 是否过期
|
||||
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
|
||||
expiresAt := parseAntigravityExpiresAt(account)
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 提前 5 分钟认为过期
|
||||
return time.Now().Add(5 * time.Minute).After(*expiresAt)
|
||||
}
|
||||
|
||||
// updateAccountTier 更新账户类型信息
|
||||
func (r *AntigravityQuotaRefresher) updateAccountTier(account *Account, loadResp *antigravity.LoadCodeAssistResponse) {
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
|
||||
tier := loadResp.GetTier()
|
||||
if tier != "" {
|
||||
account.Extra["tier"] = tier
|
||||
}
|
||||
|
||||
// 保存不符合条件的原因(如 INELIGIBLE_ACCOUNT)
|
||||
if len(loadResp.IneligibleTiers) > 0 && loadResp.IneligibleTiers[0] != nil {
|
||||
ineligible := loadResp.IneligibleTiers[0]
|
||||
if ineligible.ReasonCode != "" {
|
||||
account.Extra["ineligible_reason_code"] = ineligible.ReasonCode
|
||||
}
|
||||
if ineligible.ReasonMessage != "" {
|
||||
account.Extra["ineligible_reason_message"] = ineligible.ReasonMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateAccountQuota 更新账户的配额信息
|
||||
func (r *AntigravityQuotaRefresher) updateAccountQuota(account *Account, modelsResp *antigravity.FetchAvailableModelsResponse) {
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
|
||||
quota := make(map[string]any)
|
||||
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换 remainingFraction (0.0-1.0) 为百分比 (0-100)
|
||||
remaining := int(modelInfo.QuotaInfo.RemainingFraction * 100)
|
||||
|
||||
quota[modelName] = map[string]any{
|
||||
"remaining": remaining,
|
||||
"reset_time": modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
account.Extra["quota"] = quota
|
||||
account.Extra["last_quota_check"] = time.Now().Format(time.RFC3339)
|
||||
}
|
||||
145
backend/internal/service/antigravity_token_provider.go
Normal file
145
backend/internal/service/antigravity_token_provider.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type AntigravityTokenCache = GeminiTokenCache
|
||||
|
||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
||||
type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache AntigravityTokenCache,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
) *AntigravityTokenProvider {
|
||||
return &AntigravityTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := antigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := parseAntigravityExpiresAt(account)
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
||||
if p.antigravityOAuthService == nil {
|
||||
return "", errors.New("antigravity oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func antigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
}
|
||||
return "ag:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func parseAntigravityExpiresAt(account *Account) *time.Time {
|
||||
raw := strings.TrimSpace(account.GetCredential("expires_at"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
|
||||
t := time.Unix(unixSec, 0)
|
||||
return &t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
57
backend/internal/service/antigravity_token_refresher.go
Normal file
57
backend/internal/service/antigravity_token_refresher.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AntigravityTokenRefresher 实现 TokenRefresher 接口
|
||||
type AntigravityTokenRefresher struct {
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
|
||||
return &AntigravityTokenRefresher{
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否可以刷新此账户
|
||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
}
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -18,9 +18,10 @@ const (
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
|
||||
777
backend/internal/service/gateway_multiplatform_test.go
Normal file
777
backend/internal/service/gateway_multiplatform_test.go
Normal file
@@ -0,0 +1,777 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testConfig 返回一个用于测试的默认配置
|
||||
func testConfig() *config.Config {
|
||||
return &config.Config{RunMode: config.RunModeStandard}
|
||||
}
|
||||
|
||||
// mockAccountRepoForPlatform 单平台测试用的 mock
|
||||
type mockAccountRepoForPlatform struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
if m.listPlatformFunc != nil {
|
||||
return m.listPlatformFunc(ctx, platform)
|
||||
}
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
// Stub methods to implement AccountRepository interface
|
||||
func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListActive(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
for _, acc := range m.accounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
|
||||
|
||||
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
|
||||
type mockGatewayCacheForPlatform struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
m.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_Anthropic 测试 anthropic 单平台选择
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_Anthropic(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的 anthropic 账户")
|
||||
require.Equal(t, PlatformAnthropic, acc.Platform, "应只返回 anthropic 平台账户")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_Antigravity 测试 antigravity 单平台选择
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform, "应只返回 antigravity 平台账户")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed 测试优先级和最后使用时间
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: ptr(now.Add(-2 * time.Hour))},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded 测试所有账户被排除
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_Schedulability 测试账户可调度性检查
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_Schedulability(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accounts []Account
|
||||
expectedID int64
|
||||
}{
|
||||
{
|
||||
name: "过载账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "限流账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, RateLimitResetAt: ptr(now.Add(1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "非active账户被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: "error", Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "schedulable=false被跳过",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 2,
|
||||
},
|
||||
{
|
||||
name: "过期的过载账户可调度",
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, OverloadUntil: ptr(now.Add(-1 * time.Hour))},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
expectedID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: tt.accounts,
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, tt.expectedID, acc.ID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_StickySession 测试粘性会话
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("粘性会话命中-同平台", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话不匹配平台-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 粘性会话绑定但平台不匹配
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1}, // 绑定 antigravity 账户
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
// 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户平台不匹配,应降级选择同平台账户")
|
||||
require.Equal(t, PlatformAnthropic, acc.Platform)
|
||||
})
|
||||
|
||||
t.Run("粘性会话账户被排除-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", excludedIDs, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户被排除,应选择其他账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话账户不可调度-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: "error", Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户不可调度,应选择其他账户")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
svc := &GatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Antigravity平台-支持claude模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-支持gemini模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-不支持gpt模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformAnthropic},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
|
||||
},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Anthropic平台-有映射配置-支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
|
||||
},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isModelSupportedByAccount(tt.account, tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
|
||||
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
|
||||
})
|
||||
|
||||
t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "未启用mixed_scheduling的antigravity账户应被过滤")
|
||||
require.Equal(t, PlatformAnthropic, acc.Platform)
|
||||
})
|
||||
|
||||
t.Run("混合调度-粘性会话命中启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 2},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
|
||||
})
|
||||
|
||||
t.Run("混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"session-123": 2},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户")
|
||||
})
|
||||
|
||||
t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform)
|
||||
})
|
||||
|
||||
t.Run("混合调度-无可用账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
|
||||
func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account Account
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "非antigravity平台-返回false",
|
||||
account: Account{Platform: PlatformAnthropic},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "antigravity平台-无extra-返回false",
|
||||
account: Account{Platform: PlatformAntigravity},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "antigravity平台-extra无mixed_scheduling-返回false",
|
||||
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{}},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "antigravity平台-mixed_scheduling=false-返回false",
|
||||
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": false}},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "antigravity平台-mixed_scheduling=true-返回true",
|
||||
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "antigravity平台-mixed_scheduling非bool类型-返回false",
|
||||
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": "true"}},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.account.IsMixedSchedulingEnabled()
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -93,6 +94,7 @@ func (e *UpstreamFailoverError) Error() string {
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
@@ -109,6 +111,7 @@ type GatewayService struct {
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
@@ -123,6 +126,7 @@ func NewGatewayService(
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
@@ -291,16 +295,53 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
}
|
||||
|
||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
|
||||
if hasForcePlatform && groupID != nil {
|
||||
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
if err == nil {
|
||||
return account, nil
|
||||
}
|
||||
// 分组中找不到,回退查询全部该平台账户
|
||||
groupID = nil
|
||||
}
|
||||
|
||||
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
|
||||
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
@@ -310,16 +351,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:忽略 groupID,查询所有可用账号
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -332,19 +373,16 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 检查模型支持
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// 优先选择priority值更小的(priority值越小优先级越高)
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// 优先级相同时,选最久未用的
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
@@ -377,6 +415,126 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
// Antigravity 平台使用专门的模型支持检查
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
||||
func IsAntigravityModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
@@ -1116,6 +1274,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
||||
// 参考 Antigravity-Manager 和 proxycast 实现
|
||||
if account.Platform == PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
var req struct {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
|
||||
@@ -33,26 +34,32 @@ const (
|
||||
)
|
||||
|
||||
type GeminiMessagesCompatService struct {
|
||||
accountRepo AccountRepository
|
||||
cache GatewayCache
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
cache GatewayCache
|
||||
tokenProvider *GeminiTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
}
|
||||
|
||||
func NewGeminiMessagesCompatService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
cache GatewayCache,
|
||||
tokenProvider *GeminiTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
) *GeminiMessagesCompatService {
|
||||
return &GeminiMessagesCompatService{
|
||||
accountRepo: accountRepo,
|
||||
cache: cache,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,26 +73,71 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
platform = PlatformGemini
|
||||
}
|
||||
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
var queryPlatforms []string
|
||||
if useMixedScheduling {
|
||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
||||
} else {
|
||||
queryPlatforms = []string{platform}
|
||||
}
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
if account.Platform == platform {
|
||||
valid = true
|
||||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
valid = true
|
||||
}
|
||||
if valid {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||
if len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -97,7 +149,12 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
||||
// 非混合调度模式(antigravity 分组):不需要过滤
|
||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
@@ -139,6 +196,34 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
|
||||
// GetAntigravityGatewayService 返回 AntigravityGatewayService
|
||||
func (s *GeminiMessagesCompatService) GetAntigravityGatewayService() *AntigravityGatewayService {
|
||||
return s.antigravityGatewayService
|
||||
}
|
||||
|
||||
// HasAntigravityAccounts 检查是否有可用的 antigravity 账户
|
||||
func (s *GeminiMessagesCompatService) HasAntigravityAccounts(ctx context.Context, groupID *int64) (bool, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAntigravity)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAntigravity)
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(accounts) > 0, nil
|
||||
}
|
||||
|
||||
// SelectAccountForAIStudioEndpoints selects an account that is likely to succeed against
|
||||
// generativelanguage.googleapis.com (e.g. GET /v1beta/models).
|
||||
//
|
||||
@@ -1798,7 +1883,7 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
||||
if statusCode != 429 {
|
||||
return
|
||||
}
|
||||
resetAt := parseGeminiRateLimitResetTime(body)
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
if resetAt == nil {
|
||||
ra := time.Now().Add(5 * time.Minute)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
@@ -1807,7 +1892,8 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||
}
|
||||
|
||||
func parseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||
// Try to parse metadata.quotaResetDelay like "12.345s"
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(body, &parsed); err == nil {
|
||||
|
||||
493
backend/internal/service/gemini_multiplatform_test.go
Normal file
493
backend/internal/service/gemini_multiplatform_test.go
Normal file
@@ -0,0 +1,493 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
// 测试时不区分 groupID,直接按 platform 过滤
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
// Stub methods to implement AccountRepository interface
|
||||
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
for _, acc := range m.accounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
// Stub methods to implement GroupRepository interface
|
||||
func (m *mockGroupRepoForGemini) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||
func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
m.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 无分组时使用 gemini 平台
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的 gemini 账户")
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被选择
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
1: {ID: 1, Platform: PlatformAntigravity},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
groupID := int64(1)
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform, "antigravity 分组应只返回 antigravity 账户")
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred 测试 OAuth 优先
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
|
||||
require.Equal(t, AccountTypeOAuth, acc.Type)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts 测试无可用账户
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available")
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession 测试粘性会话
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("粘性会话命中-同平台", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 注意:缓存键使用 "gemini:" 前缀
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话平台不匹配-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 粘性会话绑定
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1}, // 绑定 antigravity 账户
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 无分组时使用 gemini 平台,粘性会话绑定的 antigravity 账户平台不匹配
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户平台不匹配,应降级选择 gemini 账户")
|
||||
require.Equal(t, PlatformGemini, acc.Platform)
|
||||
})
|
||||
|
||||
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 缓存键没有 "gemini:" 前缀,不应命中
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台走ForwardNative",
|
||||
platform: PlatformGemini,
|
||||
expectedService: "gemini",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台走ForwardGemini",
|
||||
platform: PlatformAntigravity,
|
||||
expectedService: "antigravity",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: tt.platform}
|
||||
|
||||
// 模拟 Handler 层的路由逻辑
|
||||
var serviceName string
|
||||
if account.Platform == PlatformAntigravity {
|
||||
serviceName = "antigravity"
|
||||
} else {
|
||||
serviceName = "gemini"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, serviceName,
|
||||
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Antigravity平台-支持gemini模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-支持claude模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-不支持gpt模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isModelSupportedByAccount(tt.account, tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ func NewTokenRefreshService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
@@ -40,6 +41,7 @@ func NewTokenRefreshService(
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
NewOpenAITokenRefresher(openaiOAuthService),
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
NewAntigravityTokenRefresher(antigravityOAuthService),
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
@@ -17,7 +17,7 @@ type BuildInfo struct {
|
||||
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
|
||||
svc := NewPricingService(cfg, remoteClient)
|
||||
if err := svc.Initialize(); err != nil {
|
||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||
// Pricing service initialization failure should not block startup, use fallback prices
|
||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||
}
|
||||
return svc, nil
|
||||
@@ -39,9 +39,10 @@ func ProvideTokenRefreshService(
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, cfg)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -53,6 +54,18 @@ func ProvideTimingWheelService() *TimingWheelService {
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
|
||||
func ProvideAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
oauthSvc *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideDeferredService creates and starts DeferredService
|
||||
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
|
||||
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
|
||||
@@ -81,8 +94,11 @@ var ProviderSet = wire.NewSet(
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
NewAntigravityTokenProvider,
|
||||
NewAntigravityGatewayService,
|
||||
NewRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
@@ -98,4 +114,5 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideTokenRefreshService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDeferredService,
|
||||
ProvideAntigravityQuotaRefresher,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user