test(gemini): 添加 Drive API 和 OAuth 服务单元测试
- 新增 drive_client_test.go:Drive API 客户端单元测试 - 新增 gemini_oauth_service_test.go:OAuth 服务单元测试 - 重构 account_handler.go:改进 RefreshTier API 实现 - 优化 drive_client.go:增强错误处理和重试逻辑 - 完善 repository 和 service 层:支持批量 tier 刷新 - 更新迁移文件编号:017 -> 024(避免冲突)
This commit is contained in:
@@ -3,7 +3,7 @@ package admin
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth-related operations for accounts
|
||||
@@ -1000,47 +1001,33 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
ctx := c.Request.Context()
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
if account.Credentials == nil || account.Credentials["oauth_type"] != "google_one" {
|
||||
response.BadRequest(c, "Account is not a google_one OAuth account")
|
||||
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
|
||||
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, ok := account.Credentials["access_token"].(string)
|
||||
if !ok || accessToken == "" {
|
||||
response.BadRequest(c, "Missing access_token in credentials")
|
||||
oauthType, _ := account.Credentials["oauth_type"].(string)
|
||||
if oauthType != "google_one" {
|
||||
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(c.Request.Context(), accessToken, proxyURL)
|
||||
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
if storageInfo != nil {
|
||||
account.Extra["drive_storage_limit"] = storageInfo.Limit
|
||||
account.Extra["drive_storage_usage"] = storageInfo.Usage
|
||||
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
|
||||
}
|
||||
account.Credentials["tier_id"] = tierID
|
||||
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
Extra: account.Extra,
|
||||
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
|
||||
Credentials: creds,
|
||||
Extra: extra,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
@@ -1049,9 +1036,10 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"tier_id": tierID,
|
||||
"drive_storage_limit": account.Extra["drive_storage_limit"],
|
||||
"drive_storage_usage": account.Extra["drive_storage_usage"],
|
||||
"updated_at": account.Extra["drive_tier_updated_at"],
|
||||
"storage_info": extra,
|
||||
"drive_storage_limit": extra["drive_storage_limit"],
|
||||
"drive_storage_usage": extra["drive_storage_usage"],
|
||||
"updated_at": extra["drive_tier_updated_at"],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1069,7 +1057,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var accounts []service.Account
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||
@@ -1077,84 +1065,87 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
for _, acc := range allAccounts {
|
||||
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
|
||||
for i := range allAccounts {
|
||||
acc := &allAccounts[i]
|
||||
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||
if oauthType == "google_one" {
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, id := range req.AccountIDs {
|
||||
acc, err := h.adminService.GetAccount(ctx, id)
|
||||
if err != nil {
|
||||
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, acc := range fetched {
|
||||
if acc == nil {
|
||||
continue
|
||||
}
|
||||
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
|
||||
accounts = append(accounts, *acc)
|
||||
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
|
||||
continue
|
||||
}
|
||||
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||
if oauthType != "google_one" {
|
||||
continue
|
||||
}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
}
|
||||
|
||||
total := len(accounts)
|
||||
success := 0
|
||||
failed := 0
|
||||
errors := []gin.H{}
|
||||
const maxConcurrency = 10
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(maxConcurrency)
|
||||
|
||||
var mu sync.Mutex
|
||||
results := gin.H{
|
||||
"total": len(accounts),
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"errors": []gin.H{},
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
accessToken, ok := account.Credentials["access_token"].(string)
|
||||
if !ok || accessToken == "" {
|
||||
failed++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": account.ID,
|
||||
"error": "missing access_token",
|
||||
acc := account // 闭包捕获
|
||||
g.Go(func() error {
|
||||
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
results["failed"] = results["failed"].(int) + 1
|
||||
results["errors"] = append(results["errors"].([]gin.H), gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
|
||||
Credentials: creds,
|
||||
Extra: extra,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
mu.Lock()
|
||||
if updateErr != nil {
|
||||
results["failed"] = results["failed"].(int) + 1
|
||||
results["errors"] = append(results["errors"].([]gin.H), gin.H{
|
||||
"account_id": acc.ID,
|
||||
"error": updateErr.Error(),
|
||||
})
|
||||
} else {
|
||||
results["success"] = results["success"].(int) + 1
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
failed++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": account.ID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if account.Extra == nil {
|
||||
account.Extra = make(map[string]any)
|
||||
}
|
||||
if storageInfo != nil {
|
||||
account.Extra["drive_storage_limit"] = storageInfo.Limit
|
||||
account.Extra["drive_storage_usage"] = storageInfo.Usage
|
||||
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
|
||||
}
|
||||
account.Credentials["tier_id"] = tierID
|
||||
|
||||
_, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
Extra: account.Extra,
|
||||
return nil
|
||||
})
|
||||
if updateErr != nil {
|
||||
failed++
|
||||
errors = append(errors, gin.H{
|
||||
"account_id": account.ID,
|
||||
"error": updateErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"total": total,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"errors": errors,
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, results)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user