fix(backend): 适配重构后的架构修复 Gemini OAuth 集成
## 主要修改 1. **移除 model 包引用** - 删除所有 `internal/model` 包的 import - 使用 service 包中的类型定义(Account, Platform常量等) 2. **修复类型转换** - JSONB → map[string]any - 添加 mergeJSONB 辅助函数 - 添加 Account.IsGemini() 方法 3. **更新中间件调用** - GetUserFromContext → GetAuthSubjectFromContext - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象) 4. **修复 handler 层** - 更新 gemini_v1beta_handler.go - 修正 billing 检查和 usage 记录 ## 影响范围 - backend/internal/service/gemini_*.go - backend/internal/service/account_test_service.go - backend/internal/service/crs_sync_service.go - backend/internal/handler/gemini_v1beta_handler.go - backend/internal/handler/gateway_handler.go - backend/internal/handler/admin/account_handler.go
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
|
||||
@@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
||||
return s.tokenProvider
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var accounts []model.Account
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
@@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||||
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
@@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
// 2) OAuth accounts without project_id (AI Studio OAuth)
|
||||
// 3) OAuth accounts explicitly marked as ai_studio
|
||||
// 4) Any remaining Gemini accounts (fallback)
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
}
|
||||
|
||||
rank := func(a *model.Account) int {
|
||||
rank := func(a *Account) int {
|
||||
if a == nil {
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
return 9
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
|
||||
return 1
|
||||
}
|
||||
@@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
}
|
||||
}
|
||||
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if selected == nil {
|
||||
@@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
@@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var req struct {
|
||||
@@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
var requestIDHeader string
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
useUpstreamStream := req.Stream
|
||||
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||||
useUpstreamStream = true
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
@@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if strings.TrimSpace(originalModel) == "" {
|
||||
@@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
useUpstreamStream := stream
|
||||
upstreamAction := action
|
||||
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||||
useUpstreamStream = true
|
||||
upstreamAction = "streamGenerateContent"
|
||||
@@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
@@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
isOAuth := account.Type == model.AccountTypeOAuth
|
||||
isOAuth := account.Type == AccountTypeOAuth
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
@@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool {
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
return true
|
||||
case 403:
|
||||
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
|
||||
if account == nil || account.Type != model.AccountTypeOAuth {
|
||||
if account == nil || account.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
|
||||
@@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
// endpoints like /v1beta/models and /v1beta/models/{model}.
|
||||
//
|
||||
// This is used to support Gemini SDKs that call models listing endpoints before generation.
|
||||
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) {
|
||||
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
|
||||
if account == nil {
|
||||
return nil, errors.New("account is nil")
|
||||
}
|
||||
@@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("gemini token provider not configured")
|
||||
}
|
||||
@@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) {
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user