fix(backend): 适配重构后的架构修复 Gemini OAuth 集成

## 主要修改

1. **移除 model 包引用**
   - 删除所有 `internal/model` 包的 import
   - 使用 service 包中的类型定义(Account, Platform常量等)

2. **修复类型转换**
   - JSONB → map[string]any
   - 添加 mergeJSONB 辅助函数
   - 添加 Account.IsGemini() 方法

3. **更新中间件调用**
   - GetUserFromContext → GetAuthSubjectFromContext
   - 适配新的并发控制签名(传递 ID 和 Concurrency 而不是完整对象)

4. **修复 handler 层**
   - 更新 gemini_v1beta_handler.go
   - 修正 billing 检查和 usage 记录

## 影响范围
- backend/internal/service/gemini_*.go
- backend/internal/service/account_test_service.go
- backend/internal/service/crs_sync_service.go
- backend/internal/handler/gemini_v1beta_handler.go
- backend/internal/handler/gateway_handler.go
- backend/internal/handler/admin/account_handler.go
This commit is contained in:
IanShaw027
2025-12-26 22:07:55 +08:00
parent bfcd9501c2
commit 9db52838b5
10 changed files with 100 additions and 87 deletions

View File

@@ -350,7 +350,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
newCredentials[k] = v
}
}
} else if account.Platform == model.PlatformGemini {
} else if account.Platform == service.PlatformGemini {
tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())

View File

@@ -128,8 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 选择支持该模型的账号
var account *model.Account
if platform == model.PlatformGemini {
var account *service.Account
if platform == service.PlatformGemini {
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
} else {
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
@@ -162,7 +162,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求
var result *service.ForwardResult
if platform == model.PlatformGemini {
if platform == service.PlatformGemini {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -25,7 +24,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
@@ -56,7 +55,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
@@ -94,13 +93,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok || user == nil {
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
@@ -130,19 +129,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID)
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted)
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
@@ -152,7 +151,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
@@ -166,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted)
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
@@ -190,7 +189,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {

View File

@@ -70,6 +70,10 @@ func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
@@ -322,3 +326,17 @@ func (a *Account) IsOpenAITokenExpired() bool {
}
return time.Now().Add(60 * time.Second).After(*expiresAt)
}
// mergeJSONB merges source map into target map (for preserving extra fields during account sync)
func mergeJSONB(target, source map[string]any) map[string]any {
if target == nil {
target = make(map[string]any)
}
if source == nil {
return target
}
for k, v := range source {
target[k] = v
}
return target
}

View File

@@ -387,7 +387,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
}
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
@@ -397,7 +397,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -421,9 +421,9 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case model.AccountTypeOAuth:
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
@@ -458,7 +458,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
return nil, fmt.Errorf("no API key available")
@@ -485,7 +485,7 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou
}
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
if s.geminiTokenProvider == nil {
return nil, fmt.Errorf("gemini token provider not configured")
}

View File

@@ -772,12 +772,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: 3,
Priority: clampPriority(src.Priority),
@@ -803,8 +803,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini
existing.Type = model.AccountTypeOAuth
existing.Platform = PlatformGemini
existing.Type = AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -883,12 +883,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: 3,
Priority: clampPriority(src.Priority),
@@ -910,8 +910,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini
existing.Type = model.AccountTypeApiKey
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -1200,7 +1200,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account
}
}
}
case model.PlatformGemini:
case PlatformGemini:
if s.geminiOAuthService == nil {
return nil
}

View File

@@ -18,7 +18,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
return s.tokenProvider
}
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
if err == nil && accountID > 0 {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
@@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
@@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
// 2) OAuth accounts without project_id (AI Studio OAuth)
// 3) OAuth accounts explicitly marked as ai_studio
// 4) Any remaining Gemini accounts (fallback)
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) {
var accounts []model.Account
func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return nil, errors.New("no available Gemini accounts")
}
rank := func(a *model.Account) int {
rank := func(a *Account) int {
if a == nil {
return 999
}
switch a.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
return 9
case model.AccountTypeOAuth:
case AccountTypeOAuth:
if strings.TrimSpace(a.GetCredential("project_id")) == "" {
return 1
}
@@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
}
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
if selected == nil {
@@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth {
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
@@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return selected, nil
}
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
var req struct {
@@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
var requestIDHeader string
var buildReq func(ctx context.Context) (*http.Request, string, error)
useUpstreamStream := req.Stream
if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
}
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
case model.AccountTypeOAuth:
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
@@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}, nil
}
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
if strings.TrimSpace(originalModel) == "" {
@@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
useUpstreamStream := stream
upstreamAction := action
if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" {
// Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate.
useUpstreamStream = true
upstreamAction = "streamGenerateContent"
@@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
case model.AccountTypeOAuth:
case AccountTypeOAuth:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
if s.tokenProvider == nil {
return nil, "", errors.New("gemini token provider not configured")
@@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
c.Header("x-request-id", requestID)
}
isOAuth := account.Type == model.AccountTypeOAuth
isOAuth := account.Type == AccountTypeOAuth
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool {
func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool {
switch statusCode {
case 429, 500, 502, 503, 504, 529:
return true
case 403:
// GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry.
if account == nil || account.Type != model.AccountTypeOAuth {
if account == nil || account.Type != AccountTypeOAuth {
return false
}
oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type")))
@@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
// endpoints like /v1beta/models and /v1beta/models/{model}.
//
// This is used to support Gemini SDKs that call models listing endpoints before generation.
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) {
func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) {
if account == nil {
return nil, errors.New("account is nil")
}
@@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
}
req.Header.Set("x-goog-api-key", apiKey)
case model.AccountTypeOAuth:
case AccountTypeOAuth:
if s.tokenProvider == nil {
return nil, errors.New("gemini token provider not configured")
}
@@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) {
}
}
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) {
func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
return

View File

@@ -13,7 +13,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
)
@@ -304,8 +303,8 @@ func isNonRetryableGeminiOAuthError(err error) bool {
return false
}
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) {
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth {
func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) {
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("account is not a Gemini OAuth account")
}

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
const (
@@ -34,11 +33,11 @@ func NewGeminiTokenProvider(
}
}
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) {
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth {
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
return "", errors.New("not a gemini oauth account")
}
@@ -83,7 +82,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
newCredentials[k] = v
}
}
account.Credentials = model.JSONB(newCredentials)
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = parseExpiresAt(account)
}
@@ -122,7 +121,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
detected = strings.TrimSpace(detected)
if detected != "" {
if account.Credentials == nil {
account.Credentials = model.JSONB{}
account.Credentials = make(map[string]any)
}
account.Credentials["project_id"] = detected
_ = p.accountRepo.Update(ctx, account)
@@ -149,7 +148,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model
return accessToken, nil
}
func geminiTokenCacheKey(account *model.Account) string {
func geminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return projectID
@@ -157,7 +156,7 @@ func geminiTokenCacheKey(account *model.Account) string {
return "account:" + strconv.FormatInt(account.ID, 10)
}
func parseExpiresAt(account *model.Account) *time.Time {
func parseExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil

View File

@@ -5,7 +5,6 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
type GeminiTokenRefresher struct {
@@ -16,11 +15,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
}
func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
}
func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
@@ -36,7 +35,7 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
return time.Until(expiryTime) < refreshWindow
}
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err