fix: 修复Oauth账号自动刷新token失败的bug
This commit is contained in:
@@ -73,6 +73,10 @@ func provideCleanup(
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -106,6 +106,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||
@@ -138,6 +139,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
Concurrency: concurrencyService,
|
||||
Identity: identityService,
|
||||
Update: updateService,
|
||||
TokenRefresh: tokenRefreshService,
|
||||
}
|
||||
repositories := &repository.Repositories{
|
||||
User: userRepository,
|
||||
@@ -187,6 +189,10 @@ func provideCleanup(
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{"TokenRefreshService", func() error {
|
||||
services.TokenRefresh.Stop()
|
||||
return nil
|
||||
}},
|
||||
{"PricingService", func() error {
|
||||
services.Pricing.Stop()
|
||||
return nil
|
||||
|
||||
@@ -8,15 +8,30 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
}
|
||||
|
||||
// TokenRefreshConfig OAuth token自动刷新配置
|
||||
type TokenRefreshConfig struct {
|
||||
// 是否启用自动刷新
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
// 检查间隔(分钟)
|
||||
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
|
||||
// 提前刷新时间(小时),在token过期前多久开始刷新
|
||||
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
|
||||
// 最大重试次数
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
// 重试退避基础时间(秒)
|
||||
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||
}
|
||||
|
||||
type PricingConfig struct {
|
||||
@@ -192,6 +207,13 @@ func setDefaults() {
|
||||
|
||||
// Gateway
|
||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||
|
||||
// TokenRefresh
|
||||
viper.SetDefault("token_refresh.enabled", true)
|
||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -34,7 +33,6 @@ const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
@@ -358,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
|
||||
// 检查是否需要刷新
|
||||
needRefresh := false
|
||||
if expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
|
||||
needRefresh = true
|
||||
}
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
if needRefresh || accessToken == "" {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refresh token failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新账号凭证
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
}
|
||||
|
||||
// Token刷新由后台 TokenRefreshService 处理,此处只返回当前token
|
||||
return accessToken, "oauth", nil
|
||||
}
|
||||
|
||||
@@ -442,25 +413,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理401错误:刷新token重试
|
||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
||||
resp.Body.Close()
|
||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retry request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -619,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo.AccessToken, "oauth", nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -1053,26 +987,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 处理 401 错误:刷新 token 重试(仅 OAuth)
|
||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
||||
resp.Body.Close()
|
||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
|
||||
return fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
|
||||
return fmt.Errorf("retry request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,4 +27,5 @@ type Services struct {
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
}
|
||||
|
||||
185
backend/internal/service/token_refresh_service.go
Normal file
185
backend/internal/service/token_refresh_service.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// TokenRefreshService OAuth token自动刷新服务
|
||||
// 定期检查并刷新即将过期的token
|
||||
type TokenRefreshService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
refreshers []TokenRefresher
|
||||
cfg *config.TokenRefreshConfig
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewTokenRefreshService 创建token刷新服务
|
||||
func NewTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
accountRepo: accountRepo,
|
||||
cfg: &cfg.TokenRefresh,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 注册平台特定的刷新器
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
// 未来可以添加其他平台的刷新器:
|
||||
// NewOpenAITokenRefresher(...),
|
||||
// NewGeminiTokenRefresher(...),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
log.Println("[TokenRefresh] Service disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.refreshLoop()
|
||||
|
||||
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
|
||||
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
|
||||
}
|
||||
|
||||
// Stop 停止刷新服务
|
||||
func (s *TokenRefreshService) Stop() {
|
||||
close(s.stopCh)
|
||||
s.wg.Wait()
|
||||
log.Println("[TokenRefresh] Service stopped")
|
||||
}
|
||||
|
||||
// refreshLoop 刷新循环
|
||||
func (s *TokenRefreshService) refreshLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
// 计算检查间隔
|
||||
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
|
||||
if checkInterval < time.Minute {
|
||||
checkInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// 启动时立即执行一次检查
|
||||
s.processRefresh()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.processRefresh()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processRefresh 执行一次刷新检查
|
||||
func (s *TokenRefreshService) processRefresh() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 计算刷新窗口
|
||||
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
|
||||
|
||||
// 获取所有active状态的账号
|
||||
accounts, err := s.listActiveAccounts(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
refreshed, failed := 0, 0
|
||||
|
||||
for i := range accounts {
|
||||
account := &accounts[i]
|
||||
|
||||
// 遍历所有刷新器,找到能处理此账号的
|
||||
for _, refresher := range s.refreshers {
|
||||
if !refresher.CanRefresh(account) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否需要刷新
|
||||
if !refresher.NeedsRefresh(account, refreshWindow) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行刷新
|
||||
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||
failed++
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
|
||||
refreshed++
|
||||
}
|
||||
|
||||
// 每个账号只由一个refresher处理
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 || failed > 0 {
|
||||
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
|
||||
}
|
||||
}
|
||||
|
||||
// listActiveAccounts 获取所有active状态的账号
|
||||
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
|
||||
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
|
||||
return s.accountRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
account.Credentials = model.JSONB(newCredentials)
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||
|
||||
// 如果还有重试机会,等待后重试
|
||||
if attempt < s.cfg.MaxRetries {
|
||||
// 指数退避:2^(attempt-1) * baseSeconds
|
||||
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败,标记账号为error状态
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
90
backend/internal/service/token_refresher.go
Normal file
90
backend/internal/service/token_refresher.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
// TokenRefresher 定义平台特定的token刷新策略接口
|
||||
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
|
||||
type TokenRefresher interface {
|
||||
// CanRefresh 检查此刷新器是否能处理指定账号
|
||||
CanRefresh(account *model.Account) bool
|
||||
|
||||
// NeedsRefresh 检查账号的token是否需要刷新
|
||||
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
type ClaudeTokenRefresher struct {
|
||||
oauthService *OAuthService
|
||||
}
|
||||
|
||||
// NewClaudeTokenRefresher 创建Claude token刷新器
|
||||
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||
return &ClaudeTokenRefresher{
|
||||
oauthService: oauthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 anthropic 平台的 oauth 类型账号
|
||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformAnthropic &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留现有credentials中的所有字段
|
||||
newCredentials := make(map[string]interface{})
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
|
||||
// 只更新token相关字段
|
||||
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.Scope != "" {
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -33,6 +33,17 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
return NewEmailQueueService(emailService, 3)
|
||||
}
|
||||
|
||||
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -61,6 +72,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewConcurrencyService,
|
||||
NewIdentityService,
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
|
||||
// Provide the Services container struct
|
||||
wire.Struct(new(Services), "*"),
|
||||
|
||||
Reference in New Issue
Block a user