fix(backend): 适配重构后的架构修复 Gemini OAuth 集成

## 主要修改

1. **移除 model 包引用**
   - 删除所有 `internal/model` 包的 import
   - 使用 service 包中的类型定义(Account, Platform常量等)

2. **修复类型转换**
   - JSONB → map[string]any
   - 添加 mergeJSONB 辅助函数
   - 添加 Account.IsGemini() 方法

3. **更新中间件调用**
   - GetUserFromContext → GetAuthSubjectFromContext
   - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象)

4. **修复 handler 层**
   - 更新 gemini_v1beta_handler.go
   - 修正 billing 检查和 usage 记录

## 影响范围
- backend/internal/service/gemini_*.go
- backend/internal/service/account_test_service.go
- backend/internal/service/crs_sync_service.go
- backend/internal/handler/gemini_v1beta_handler.go
- backend/internal/handler/gateway_handler.go
- backend/internal/handler/admin/account_handler.go
This commit is contained in:
IanShaw027
2025-12-26 22:07:55 +08:00
parent bfcd9501c2
commit 9db52838b5
10 changed files with 100 additions and 87 deletions

View File

@@ -18,7 +18,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
@@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
@@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) {
var accounts []model.Account
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return nil, errors.New("no available Gemini accounts")
}
rank := func(a *model.Account) int {
rank := func(a *Account) int {
if a == nil {
return 999
}
switch a.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
return 9
case model.AccountTypeOAuth:
case AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1
}
@@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
}
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
if selected == nil {
@@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
@@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return selected, nil
}
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
var req struct {
@@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
}
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
case model.AccountTypeOAuth:
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
@@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
@@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
useUpstreamStream := stream
upstreamAction := action
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
@@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
case model.AccountTypeOAuth:
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
@@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == model.AccountTypeOAuth
isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool {
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
case 403:
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
if account == nil || account.Type != model.AccountTypeOAuth {
if account == nil || account.Type != AccountTypeOAuth {
return false
}
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
@@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) {
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
if account == nil {
return nil, errors.New("account is nil")
}
@@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
}
req.Header.Set("x-goog-api-key", apiKey)
case model.AccountTypeOAuth:
case AccountTypeOAuth:
if s.tokenProvider == nil {
return nil, errors.New("gemini token provider not configured")
}
@@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) {
}
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) {
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return