diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 0662df2e..5fa2f4e1 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -350,7 +350,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { newCredentials[k] = v } } - } else if account.Platform == model.PlatformGemini { + } else if account.Platform == service.PlatformGemini { tokenInfo, err := h.geminiOAuthService.RefreshAccountToken(c.Request.Context(), account) if err != nil { response.InternalError(c, "Failed to refresh credentials: "+err.Error()) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index b9b12be3..afb1c572 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -128,8 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 选择支持该模型的账号 - var account *model.Account - if platform == model.PlatformGemini { + var account *service.Account + if platform == service.PlatformGemini { account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) } else { account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model) @@ -162,7 +162,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 var result *service.ForwardResult - if platform == model.PlatformGemini { + if platform == service.PlatformGemini { result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) } else { result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 7f3ba5cc..6a9e2e15 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -25,7 +24,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { googleError(c, http.StatusUnauthorized, "Invalid API key") return } - if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } @@ -56,7 +55,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusUnauthorized, "Invalid API key") return } - if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } @@ -94,13 +93,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusUnauthorized, "Invalid API key") return } - user, ok := middleware.GetUserFromContext(c) - if !ok || user == nil { + authSubject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok { googleError(c, http.StatusInternalServerError, "User context not found") return } - if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } @@ -130,19 +129,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) // 0) wait queue check - maxWait := service.CalculateMaxWait(user.Concurrency) - canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) + maxWait := service.CalculateMaxWait(authSubject.Concurrency) + canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) if err != nil { log.Printf("Increment wait count failed: %v", err) } else if !canWait { googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } - defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID) + defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID) // 1) user concurrency slot streamStarted := false - userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted) + userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { googleError(c, http.StatusTooManyRequests, err.Error()) return @@ -152,7 +151,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 2) billing eligibility check (after wait) - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { googleError(c, http.StatusForbidden, err.Error()) return } @@ -166,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 4) account concurrency slot - accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted) + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted) if err != nil { googleError(c, http.StatusTooManyRequests, err.Error()) return @@ -190,7 +189,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ApiKey: apiKey, - User: user, + User: apiKey.User, Account: account, Subscription: subscription, }); err != nil { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 373ad4a9..a6b20d9d 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -70,6 +70,10 @@ func (a *Account) IsOAuth() bool { return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken } +func (a *Account) IsGemini() bool { + return a.Platform == PlatformGemini +} + func (a *Account) CanGetUsage() bool { return a.Type == AccountTypeOAuth } @@ -322,3 +326,17 @@ func (a *Account) IsOpenAITokenExpired() bool { } return time.Now().Add(60 * time.Second).After(*expiresAt) } + +// mergeJSONB merges source map into target map (for preserving extra fields during account sync) +func mergeJSONB(target, source map[string]any) map[string]any { + if target == nil { + target = make(map[string]any) + } + if source == nil { + return target + } + for k, v := range source { + target[k] = v + } + return target +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8e29b3f9..a2dd9ac2 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -387,7 +387,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } // testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error { +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := c.Request.Context() // Determine the model to use @@ -397,7 +397,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account } // For API Key accounts with model mapping, map the model - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { mapping := account.GetModelMapping() if len(mapping) > 0 { if mappedModel, exists := mapping[testModelID]; exists { @@ -421,9 +421,9 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account var err error switch account.Type { - case model.AccountTypeApiKey: + case AccountTypeApiKey: req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) - case model.AccountTypeOAuth: + case AccountTypeOAuth: req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) default: return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) @@ -458,7 +458,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account } // buildGeminiAPIKeyRequest builds request for Gemini API Key accounts -func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { +func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { apiKey := account.GetCredential("api_key") if strings.TrimSpace(apiKey) == "" { return nil, fmt.Errorf("no API key available") @@ -485,7 +485,7 @@ func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, accou } // buildGeminiOAuthRequest builds request for Gemini OAuth accounts -func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) { +func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { if s.geminiTokenProvider == nil { return nil, fmt.Errorf("gemini token provider not configured") } diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 0e50b649..fa7bad21 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -772,12 +772,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformGemini, - Type: model.AccountTypeOAuth, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: 3, Priority: clampPriority(src.Priority), @@ -803,8 +803,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput existing.Extra = mergeJSONB(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformGemini - existing.Type = model.AccountTypeOAuth + existing.Platform = PlatformGemini + existing.Type = AccountTypeOAuth existing.Credentials = mergeJSONB(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID @@ -883,12 +883,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformGemini, - Type: model.AccountTypeApiKey, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Platform: PlatformGemini, + Type: AccountTypeApiKey, + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: 3, Priority: clampPriority(src.Priority), @@ -910,8 +910,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput existing.Extra = mergeJSONB(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformGemini - existing.Type = model.AccountTypeApiKey + existing.Platform = PlatformGemini + existing.Type = AccountTypeApiKey existing.Credentials = mergeJSONB(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID @@ -1200,7 +1200,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account } } } - case model.PlatformGemini: + case PlatformGemini: if s.geminiOAuthService == nil { return nil } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index f35e8254..e2462f3a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -18,7 +18,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" @@ -62,31 +61,31 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider { return s.tokenProvider } -func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { +func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { cacheKey := "gemini:" + sessionHash if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) if err == nil && accountID > 0 { account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) return account, nil } } } - var accounts []model.Account + var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } - var selected *model.Account + var selected *Account for i := range accounts { acc := &accounts[i] if requestedModel != "" && !acc.IsModelSupported(requestedModel) { @@ -106,7 +105,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: // Prefer OAuth accounts when both are unused (more compatible for Code Assist flows). - if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { + if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth { selected = acc } default: @@ -139,13 +138,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, // 2) OAuth accounts without project_id (AI Studio OAuth) // 3) OAuth accounts explicitly marked as ai_studio // 4) Any remaining Gemini accounts (fallback) -func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*model.Account, error) { - var accounts []model.Account +func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx context.Context, groupID *int64) (*Account, error) { + var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformGemini) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformGemini) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -154,17 +153,17 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont return nil, errors.New("no available Gemini accounts") } - rank := func(a *model.Account) int { + rank := func(a *Account) int { if a == nil { return 999 } switch a.Type { - case model.AccountTypeApiKey: + case AccountTypeApiKey: if strings.TrimSpace(a.GetCredential("api_key")) != "" { return 0 } return 9 - case model.AccountTypeOAuth: + case AccountTypeOAuth: if strings.TrimSpace(a.GetCredential("project_id")) == "" { return 1 } @@ -178,7 +177,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont } } - var selected *model.Account + var selected *Account for i := range accounts { acc := &accounts[i] if selected == nil { @@ -204,7 +203,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if acc.Type == model.AccountTypeOAuth && selected.Type != model.AccountTypeOAuth { + if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth { selected = acc } default: @@ -221,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont return selected, nil } -func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { +func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { startTime := time.Now() var req struct { @@ -237,7 +236,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex originalModel := req.Model mappedModel := req.Model - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { mappedModel = account.GetMappedModel(req.Model) } @@ -254,13 +253,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex var requestIDHeader string var buildReq func(ctx context.Context) (*http.Request, string, error) useUpstreamStream := req.Stream - if account.Type == model.AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" { + if account.Type == AccountTypeOAuth && !req.Stream && strings.TrimSpace(account.GetCredential("project_id")) != "" { // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. useUpstreamStream = true } switch account.Type { - case model.AccountTypeApiKey: + case AccountTypeApiKey: buildReq = func(ctx context.Context) (*http.Request, string, error) { apiKey := account.GetCredential("api_key") if strings.TrimSpace(apiKey) == "" { @@ -291,7 +290,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } requestIDHeader = "x-request-id" - case model.AccountTypeOAuth: + case AccountTypeOAuth: buildReq = func(ctx context.Context) (*http.Request, string, error) { if s.tokenProvider == nil { return nil, "", errors.New("gemini token provider not configured") @@ -476,7 +475,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex }, nil } -func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *model.Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { +func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { startTime := time.Now() if strings.TrimSpace(originalModel) == "" { @@ -497,7 +496,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. } mappedModel := originalModel - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { mappedModel = account.GetMappedModel(originalModel) } @@ -508,7 +507,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. useUpstreamStream := stream upstreamAction := action - if account.Type == model.AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" { + if account.Type == AccountTypeOAuth && !stream && action == "generateContent" && strings.TrimSpace(account.GetCredential("project_id")) != "" { // Code Assist's non-streaming generateContent may return no content; use streaming upstream and aggregate. useUpstreamStream = true upstreamAction = "streamGenerateContent" @@ -519,7 +518,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. var buildReq func(ctx context.Context) (*http.Request, string, error) switch account.Type { - case model.AccountTypeApiKey: + case AccountTypeApiKey: buildReq = func(ctx context.Context) (*http.Request, string, error) { apiKey := account.GetCredential("api_key") if strings.TrimSpace(apiKey) == "" { @@ -546,7 +545,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. } requestIDHeader = "x-request-id" - case model.AccountTypeOAuth: + case AccountTypeOAuth: buildReq = func(ctx context.Context) (*http.Request, string, error) { if s.tokenProvider == nil { return nil, "", errors.New("gemini token provider not configured") @@ -704,7 +703,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. c.Header("x-request-id", requestID) } - isOAuth := account.Type == model.AccountTypeOAuth + isOAuth := account.Type == AccountTypeOAuth if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -776,13 +775,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } -func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool { +func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { switch statusCode { case 429, 500, 502, 503, 504, 529: return true case 403: // GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry. - if account == nil || account.Type != model.AccountTypeOAuth { + if account == nil || account.Type != AccountTypeOAuth { return false } oauthType := strings.ToLower(strings.TrimSpace(account.GetCredential("oauth_type"))) @@ -1599,7 +1598,7 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte // endpoints like /v1beta/models and /v1beta/models/{model}. // // This is used to support Gemini SDKs that call models listing endpoints before generation. -func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *model.Account, path string) (*UpstreamHTTPResult, error) { +func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, account *Account, path string) (*UpstreamHTTPResult, error) { if account == nil { return nil, errors.New("account is nil") } @@ -1625,13 +1624,13 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac } switch account.Type { - case model.AccountTypeApiKey: + case AccountTypeApiKey: apiKey := strings.TrimSpace(account.GetCredential("api_key")) if apiKey == "" { return nil, errors.New("gemini api_key not configured") } req.Header.Set("x-goog-api-key", apiKey) - case model.AccountTypeOAuth: + case AccountTypeOAuth: if s.tokenProvider == nil { return nil, errors.New("gemini token provider not configured") } @@ -1766,7 +1765,7 @@ func asInt(v any) (int, bool) { } } -func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) { +func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) return diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 89d0f3ba..36257667 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -13,7 +13,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" ) @@ -304,8 +303,8 @@ func isNonRetryableGeminiOAuthError(err error) bool { return false } -func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) { - if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { +func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*GeminiTokenInfo, error) { + if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth { return nil, fmt.Errorf("account is not a Gemini OAuth account") } diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index c75198ad..20ff378a 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/model" ) const ( @@ -34,11 +33,11 @@ func NewGeminiTokenProvider( } } -func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) { +func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") } - if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { + if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth { return "", errors.New("not a gemini oauth account") } @@ -83,7 +82,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model newCredentials[k] = v } } - account.Credentials = model.JSONB(newCredentials) + account.Credentials = newCredentials _ = p.accountRepo.Update(ctx, account) expiresAt = parseExpiresAt(account) } @@ -122,7 +121,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model detected = strings.TrimSpace(detected) if detected != "" { if account.Credentials == nil { - account.Credentials = model.JSONB{} + account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected _ = p.accountRepo.Update(ctx, account) @@ -149,7 +148,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model return accessToken, nil } -func geminiTokenCacheKey(account *model.Account) string { +func geminiTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { return projectID @@ -157,7 +156,7 @@ func geminiTokenCacheKey(account *model.Account) string { return "account:" + strconv.FormatInt(account.ID, 10) } -func parseExpiresAt(account *model.Account) *time.Time { +func parseExpiresAt(account *Account) *time.Time { raw := strings.TrimSpace(account.GetCredential("expires_at")) if raw == "" { return nil diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go index 25ad699d..b67ba832 100644 --- a/backend/internal/service/gemini_token_refresher.go +++ b/backend/internal/service/gemini_token_refresher.go @@ -5,7 +5,6 @@ import ( "strconv" "time" - "github.com/Wei-Shaw/sub2api/internal/model" ) type GeminiTokenRefresher struct { @@ -16,11 +15,11 @@ func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiToke return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} } -func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool { - return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth +func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool { + return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth } -func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { +func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { if !r.CanRefresh(account) { return false } @@ -36,7 +35,7 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo return time.Until(expiryTime) < refreshWindow } -func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { +func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account) if err != nil { return nil, err