fix(backend): 适配重构后的架构修复 Gemini OAuth 集成
## 主要修改 1. **移除 model 包引用** - 删除所有 `internal/model` 包的 import - 使用 service 包中的类型定义(Account, Platform常量等) 2. **修复类型转换** - JSONB → map[string]any - 添加 mergeJSONB 辅助函数 - 添加 Account.IsGemini() 方法 3. **更新中间件调用** - GetUserFromContext → GetAuthSubjectFromContext - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象) 4. **修复 handler 层** - 更新 gemini_v1beta_handler.go - 修正 billing 检查和 usage 记录 ## 影响范围 - backend/internal/service/gemini_*.go - backend/internal/service/account_test_service.go - backend/internal/service/crs_sync_service.go - backend/internal/handler/gemini_v1beta_handler.go - backend/internal/handler/gateway_handler.go - backend/internal/handler/admin/account_handler.go
This commit is contained in:
@@ -350,7 +350,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
} else if account.Platform == model.PlatformGemini {
|
||||
} else if account.Platform == service.PlatformGemini {
|
||||
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
|
||||
|
||||
@@ -128,8 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 选择支持该模型的账号
|
||||
var account *model.Account
|
||||
if platform == model.PlatformGemini {
|
||||
var account *service.Account
|
||||
if platform == service.PlatformGemini {
|
||||
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
} else {
|
||||
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
@@ -162,7 +162,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 转发请求
|
||||
var result *service.ForwardResult
|
||||
if platform == model.PlatformGemini {
|
||||
if platform == service.PlatformGemini {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -25,7 +24,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
@@ -56,7 +55,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
@@ -94,13 +93,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||
return
|
||||
}
|
||||
user, ok := middleware.GetUserFromContext(c)
|
||||
if !ok || user == nil {
|
||||
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
|
||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||
return
|
||||
}
|
||||
@@ -130,19 +129,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
|
||||
|
||||
// 0) wait queue check
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
|
||||
// 1) user concurrency slot
|
||||
streamStarted := false
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted)
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
@@ -152,7 +151,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
googleError(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -166,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted)
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
@@ -190,7 +189,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
|
||||
@@ -70,6 +70,10 @@ func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
@@ -322,3 +326,17 @@ func (a *Account) IsOpenAITokenExpired() bool {
|
||||
}
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
// mergeJSONB merges source map into target map (for preserving extra fields during account sync)
|
||||
func mergeJSONB(target, source map[string]any) map[string]any {
|
||||
if target == nil {
|
||||
target = make(map[string]any)
|
||||
}
|
||||
if source == nil {
|
||||
return target
|
||||
}
|
||||
for k, v := range source {
|
||||
target[k] = v
|
||||
}
|
||||
return target
|
||||
}
|
||||
|
||||
@@ -387,7 +387,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
@@ -397,7 +397,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -421,9 +421,9 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
@@ -458,7 +458,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
|
||||
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("no API key available")
|
||||
@@ -485,7 +485,7 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
|
||||
}
|
||||
|
||||
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
|
||||
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("gemini token provider not configured")
|
||||
}
|
||||
|
||||
@@ -772,12 +772,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformGemini,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: clampPriority(src.Priority),
|
||||
@@ -803,8 +803,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformGemini
|
||||
existing.Type = model.AccountTypeOAuth
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeOAuth
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -883,12 +883,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformGemini,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: clampPriority(src.Priority),
|
||||
@@ -910,8 +910,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformGemini
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -1200,7 +1200,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.PlatformGemini:
|
||||
case PlatformGemini:
|
||||
if s.geminiOAuthService == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
|
||||
@@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
||||
return s.tokenProvider
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var accounts []model.Account
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
@@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||||
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
@@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
// 2) OAuth accounts without project_id (AI Studio OAuth)
|
||||
// 3) OAuth accounts explicitly marked as ai_studio
|
||||
// 4) Any remaining Gemini accounts (fallback)
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) {
|
||||
var accounts []model.Account
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
}
|
||||
|
||||
rank := func(a *model.Account) int {
|
||||
rank := func(a *Account) int {
|
||||
if a == nil {
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
return 9
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
|
||||
return 1
|
||||
}
|
||||
@@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
}
|
||||
}
|
||||
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if selected == nil {
|
||||
@@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
@@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
var req struct {
|
||||
@@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
var requestIDHeader string
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
useUpstreamStream := req.Stream
|
||||
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||||
useUpstreamStream = true
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
@@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
if strings.TrimSpace(originalModel) == "" {
|
||||
@@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
useUpstreamStream := stream
|
||||
upstreamAction := action
|
||||
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
|
||||
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
|
||||
useUpstreamStream = true
|
||||
upstreamAction = "streamGenerateContent"
|
||||
@@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = "x-request-id"
|
||||
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
if s.tokenProvider == nil {
|
||||
return nil, "", errors.New("gemini token provider not configured")
|
||||
@@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
isOAuth := account.Type == model.AccountTypeOAuth
|
||||
isOAuth := account.Type == AccountTypeOAuth
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
@@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool {
|
||||
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 429, 500, 502, 503, 504, 529:
|
||||
return true
|
||||
case 403:
|
||||
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
|
||||
if account == nil || account.Type != model.AccountTypeOAuth {
|
||||
if account == nil || account.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
|
||||
@@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
||||
// endpoints like /v1beta/models and /v1beta/models/{model}.
|
||||
//
|
||||
// This is used to support Gemini SDKs that call models listing endpoints before generation.
|
||||
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) {
|
||||
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
|
||||
if account == nil {
|
||||
return nil, errors.New("account is nil")
|
||||
}
|
||||
@@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("gemini token provider not configured")
|
||||
}
|
||||
@@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) {
|
||||
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
return
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
@@ -304,8 +303,8 @@ func isNonRetryableGeminiOAuthError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) {
|
||||
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth {
|
||||
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
|
||||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("account is not a Gemini OAuth account")
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,11 +33,11 @@ func NewGeminiTokenProvider(
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) {
|
||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth {
|
||||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
}
|
||||
|
||||
@@ -83,7 +82,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = model.JSONB(newCredentials)
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = parseExpiresAt(account)
|
||||
}
|
||||
@@ -122,7 +121,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
|
||||
detected = strings.TrimSpace(detected)
|
||||
if detected != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = model.JSONB{}
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["project_id"] = detected
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
@@ -149,7 +148,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func geminiTokenCacheKey(account *model.Account) string {
|
||||
func geminiTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return projectID
|
||||
@@ -157,7 +156,7 @@ func geminiTokenCacheKey(account *model.Account) string {
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func parseExpiresAt(account *model.Account) *time.Time {
|
||||
func parseExpiresAt(account *Account) *time.Time {
|
||||
raw := strings.TrimSpace(account.GetCredential("expires_at"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
type GeminiTokenRefresher struct {
|
||||
@@ -16,11 +15,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
|
||||
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth
|
||||
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
@@ -36,7 +35,7 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user