fix(backend): 适配重构后的架构修复 Gemini OAuth 集成

## 主要修改

1. **移除 model 包引用**
   - 删除所有 `internal/model` 包的 import
   - 使用 service 包中的类型定义(Account, Platform常量等)

2. **修复类型转换**
   - JSONB → map[string]any
   - 添加 mergeJSONB 辅助函数
   - 添加 Account.IsGemini() 方法

3. **更新中间件调用**
   - GetUserFromContext → GetAuthSubjectFromContext
   - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象)

4. **修复 handler 层**
   - 更新 gemini_v1beta_handler.go
   - 修正 billing 检查和 usage 记录

## 影响范围
- backend/internal/service/gemini_*.go
- backend/internal/service/account_test_service.go
- backend/internal/service/crs_sync_service.go
- backend/internal/handler/gemini_v1beta_handler.go
- backend/internal/handler/gateway_handler.go
- backend/internal/handler/admin/account_handler.go
This commit is contained in:
IanShaw027
2025-12-26 22:07:55 +08:00
parent bfcd9501c2
commit 9db52838b5
10 changed files with 100 additions and 87 deletions

View File

@@ -350,7 +350,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v newCredentials[k] = v
} }
} }
} else if account.Platform == model.PlatformGemini { } else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.InternalError(c, "Failed to refresh credentials: "+err.Error())

View File

@@ -128,8 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 选择支持该模型的账号 // 选择支持该模型的账号
var account *model.Account var account *service.Account
if platform == model.PlatformGemini { if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else { } else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
@@ -162,7 +162,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 // 转发请求
var result *service.ForwardResult var result *service.ForwardResult
if platform == model.PlatformGemini { if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else { } else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)

View File

@@ -8,7 +8,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -25,7 +24,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key") googleError(c, http.StatusUnauthorized, "Invalid API key")
return return
} }
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
@@ -56,7 +55,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key") googleError(c, http.StatusUnauthorized, "Invalid API key")
return return
} }
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
@@ -94,13 +93,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key") googleError(c, http.StatusUnauthorized, "Invalid API key")
return return
} }
user, ok := middleware.GetUserFromContext(c) authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || user == nil { if !ok {
googleError(c, http.StatusInternalServerError, "User context not found") googleError(c, http.StatusInternalServerError, "User context not found")
return return
} }
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini") googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return return
} }
@@ -130,19 +129,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check // 0) wait queue check
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
} else if !canWait { } else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return return
} }
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID) defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot // 1) user concurrency slot
streamStarted := false streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted) userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil { if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
@@ -152,7 +151,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 2) billing eligibility check (after wait) // 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error()) googleError(c, http.StatusForbidden, err.Error())
return return
} }
@@ -166,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
} }
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted) accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil { if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
@@ -190,7 +189,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {

View File

@@ -70,6 +70,10 @@ func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
} }
func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }
@@ -322,3 +326,17 @@ func (a *Account) IsOpenAITokenExpired() bool {
} }
return time.Now().Add(60 * time.Second).After(*expiresAt) return time.Now().Add(60 * time.Second).After(*expiresAt)
} }
// mergeJSONB merges source map into target map (for preserving extra fields during account sync)
func mergeJSONB(target, source map[string]any) map[string]any {
if target == nil {
target = make(map[string]any)
}
if source == nil {
return target
}
for k, v := range source {
target[k] = v
}
return target
}

View File

@@ -387,7 +387,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
} }
// testGeminiAccountConnection tests a Gemini account's connection // testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Determine the model to use // Determine the model to use
@@ -397,7 +397,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// For API Key accounts with model mapping, map the model // For API Key accounts with model mapping, map the model
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mapping := account.GetModelMapping() mapping := account.GetModelMapping()
if len(mapping) > 0 { if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists { if mappedModel, exists := mapping[testModelID]; exists {
@@ -421,9 +421,9 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error var err error
switch account.Type { switch account.Type {
case model.AccountTypeApiKey: case AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case model.AccountTypeOAuth: case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
default: default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -458,7 +458,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
} }
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts // buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("no API key available") return nil, fmt.Errorf("no API key available")
@@ -485,7 +485,7 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
} }
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts // buildGeminiOAuthRequest builds request for Gemini OAuth accounts
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
if s.geminiTokenProvider == nil { if s.geminiTokenProvider == nil {
return nil, fmt.Errorf("gemini token provider not configured") return nil, fmt.Errorf("gemini token provider not configured")
} }

View File

@@ -772,12 +772,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini, Platform: PlatformGemini,
Type: model.AccountTypeOAuth, Type: AccountTypeOAuth,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: 3, Concurrency: 3,
Priority: clampPriority(src.Priority), Priority: clampPriority(src.Priority),
@@ -803,8 +803,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini existing.Platform = PlatformGemini
existing.Type = model.AccountTypeOAuth existing.Type = AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
@@ -883,12 +883,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini, Platform: PlatformGemini,
Type: model.AccountTypeApiKey, Type: AccountTypeApiKey,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: 3, Concurrency: 3,
Priority: clampPriority(src.Priority), Priority: clampPriority(src.Priority),
@@ -910,8 +910,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini existing.Platform = PlatformGemini
existing.Type = model.AccountTypeApiKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
@@ -1200,7 +1200,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
} }
} }
} }
case model.PlatformGemini: case PlatformGemini:
if s.geminiOAuthService == nil { if s.geminiOAuthService == nil {
return nil return nil
} }

View File

@@ -18,7 +18,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider return s.tokenProvider
} }
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
cacheKey := "gemini:" + sessionHash cacheKey := "gemini:" + sessionHash
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 { if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID) account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil return account, nil
} }
} }
} }
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
@@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// keep selected (never used is preferred) // keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows). // Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc selected = acc
} }
default: default:
@@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// 2) OAuth accounts without project_id (AI Studio OAuth) // 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio // 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback) // 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return nil, errors.New("no available Gemini accounts") return nil, errors.New("no available Gemini accounts")
} }
rank := func(a *model.Account) int { rank := func(a *Account) int {
if a == nil { if a == nil {
return 999 return 999
} }
switch a.Type { switch a.Type {
case model.AccountTypeApiKey: case AccountTypeApiKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" { if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0 return 0
} }
return 9 return 9
case model.AccountTypeOAuth: case AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" { if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1 return 1
} }
@@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
} }
} }
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
if selected == nil { if selected == nil {
@@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
case acc.LastUsedAt != nil && selected.LastUsedAt == nil: case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected // keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc selected = acc
} }
default: default:
@@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return selected, nil return selected, nil
} }
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
var req struct { var req struct {
@@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model originalModel := req.Model
mappedModel := req.Model mappedModel := req.Model
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(req.Model) mappedModel = account.GetMappedModel(req.Model)
} }
@@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var requestIDHeader string var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error) var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream useUpstreamStream := req.Stream
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" { if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true useUpstreamStream = true
} }
switch account.Type { switch account.Type {
case model.AccountTypeApiKey: case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
@@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
requestIDHeader = "x-request-id" requestIDHeader = "x-request-id"
case model.AccountTypeOAuth: case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil { if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured") return nil, "", errors.New("gemini token provider not configured")
@@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil }, nil
} }
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
if strings.TrimSpace(originalModel) == "" { if strings.TrimSpace(originalModel) == "" {
@@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
mappedModel := originalModel mappedModel := originalModel
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(originalModel) mappedModel = account.GetMappedModel(originalModel)
} }
@@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
useUpstreamStream := stream useUpstreamStream := stream
upstreamAction := action upstreamAction := action
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" { if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true useUpstreamStream = true
upstreamAction = "streamGenerateContent" upstreamAction = "streamGenerateContent"
@@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error) var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type { switch account.Type {
case model.AccountTypeApiKey: case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" { if strings.TrimSpace(apiKey) == "" {
@@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
requestIDHeader = "x-request-id" requestIDHeader = "x-request-id"
case model.AccountTypeOAuth: case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) { buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil { if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured") return nil, "", errors.New("gemini token provider not configured")
@@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
c.Header("x-request-id", requestID) c.Header("x-request-id", requestID)
} }
isOAuth := account.Type == model.AccountTypeOAuth isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil }, nil
} }
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool { func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode { switch statusCode {
case 429, 500, 502, 503, 504, 529: case 429, 500, 502, 503, 504, 529:
return true return true
case 403: case 403:
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry. // GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
if account == nil || account.Type != model.AccountTypeOAuth { if account == nil || account.Type != AccountTypeOAuth {
return false return false
} }
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type"))) oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
@@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
// endpoints like /v1beta/models and /v1beta/models/{model}. // endpoints like /v1beta/models and /v1beta/models/{model}.
// //
// This is used to support Gemini SDKs that call models listing endpoints before generation. // This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) { func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
if account == nil { if account == nil {
return nil, errors.New("account is nil") return nil, errors.New("account is nil")
} }
@@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
} }
switch account.Type { switch account.Type {
case model.AccountTypeApiKey: case AccountTypeApiKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key")) apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" { if apiKey == "" {
return nil, errors.New("gemini api_key not configured") return nil, errors.New("gemini api_key not configured")
} }
req.Header.Set("x-goog-api-key", apiKey) req.Header.Set("x-goog-api-key", apiKey)
case model.AccountTypeOAuth: case AccountTypeOAuth:
if s.tokenProvider == nil { if s.tokenProvider == nil {
return nil, errors.New("gemini token provider not configured") return nil, errors.New("gemini token provider not configured")
} }
@@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) {
} }
} }
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) { func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return return

View File

@@ -13,7 +13,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
) )
@@ -304,8 +303,8 @@ func isNonRetryableGeminiOAuthError(err error) bool {
return false return false
} }
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("account is not a Gemini OAuth account") return nil, fmt.Errorf("account is not a Gemini OAuth account")
} }

View File

@@ -8,7 +8,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
const ( const (
@@ -34,11 +33,11 @@ func NewGeminiTokenProvider(
} }
} }
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) { func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil { if account == nil {
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account") return "", errors.New("not a gemini oauth account")
} }
@@ -83,7 +82,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
newCredentials[k] = v newCredentials[k] = v
} }
} }
account.Credentials = model.JSONB(newCredentials) account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account) _ = p.accountRepo.Update(ctx, account)
expiresAt = parseExpiresAt(account) expiresAt = parseExpiresAt(account)
} }
@@ -122,7 +121,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
detected = strings.TrimSpace(detected) detected = strings.TrimSpace(detected)
if detected != "" { if detected != "" {
if account.Credentials == nil { if account.Credentials == nil {
account.Credentials = model.JSONB{} account.Credentials = make(map[string]any)
} }
account.Credentials["project_id"] = detected account.Credentials["project_id"] = detected
_ = p.accountRepo.Update(ctx, account) _ = p.accountRepo.Update(ctx, account)
@@ -149,7 +148,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
return accessToken, nil return accessToken, nil
} }
func geminiTokenCacheKey(account *model.Account) string { func geminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
return projectID return projectID
@@ -157,7 +156,7 @@ func geminiTokenCacheKey(account *model.Account) string {
return "account:" + strconv.FormatInt(account.ID, 10) return "account:" + strconv.FormatInt(account.ID, 10)
} }
func parseExpiresAt(account *model.Account) *time.Time { func parseExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at")) raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" { if raw == "" {
return nil return nil

View File

@@ -5,7 +5,6 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
type GeminiTokenRefresher struct { type GeminiTokenRefresher struct {
@@ -16,11 +15,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
} }
func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool { func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
} }
func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if !r.CanRefresh(account) { if !r.CanRefresh(account) {
return false return false
} }
@@ -36,7 +35,7 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
return time.Until(expiryTime) < refreshWindow return time.Until(expiryTime) < refreshWindow
} }
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account) tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
return nil, err return nil, err