test(gemini): 添加 Drive API 和 OAuth 服务单元测试

- 新增 drive_client_test.go:Drive API 客户端单元测试
- 新增 gemini_oauth_service_test.go:OAuth 服务单元测试
- 重构 account_handler.go:改进 RefreshTier API 实现
- 优化 drive_client.go:增强错误处理和重试逻辑
- 完善 repository 和 service 层:支持批量 tier 刷新
- 更新迁移文件编号:017 -> 024(避免冲突)
This commit is contained in:
IanShaw027
2026-01-01 15:07:16 +08:00
parent 34bbfb5dd2
commit 48764e15a5
9 changed files with 383 additions and 118 deletions

View File

@@ -3,7 +3,7 @@ package admin
import (
"strconv"
"strings"
"time"
"sync"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
// OAuthHandler handles OAuth-related operations for accounts
@@ -1000,47 +1001,33 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
ctx := c.Request.Context()
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if account.Credentials == nil || account.Credentials["oauth_type"] != "google_one" {
response.BadRequest(c, "Account is not a google_one OAuth account")
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
return
}
accessToken, ok := account.Credentials["access_token"].(string)
if !ok || accessToken == "" {
response.BadRequest(c, "Missing access_token in credentials")
oauthType, _ := account.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
return
}
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(c.Request.Context(), accessToken, proxyURL)
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
if err != nil {
response.ErrorFrom(c, err)
return
}
if account.Extra == nil {
account.Extra = make(map[string]any)
}
if storageInfo != nil {
account.Extra["drive_storage_limit"] = storageInfo.Limit
account.Extra["drive_storage_usage"] = storageInfo.Usage
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
}
account.Credentials["tier_id"] = tierID
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: account.Credentials,
Extra: account.Extra,
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
@@ -1049,9 +1036,10 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
response.Success(c, gin.H{
"tier_id": tierID,
"drive_storage_limit": account.Extra["drive_storage_limit"],
"drive_storage_usage": account.Extra["drive_storage_usage"],
"updated_at": account.Extra["drive_tier_updated_at"],
"storage_info": extra,
"drive_storage_limit": extra["drive_storage_limit"],
"drive_storage_usage": extra["drive_storage_usage"],
"updated_at": extra["drive_tier_updated_at"],
})
}
@@ -1069,7 +1057,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
}
ctx := c.Request.Context()
var accounts []service.Account
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
@@ -1077,84 +1065,87 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
for _, acc := range allAccounts {
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
for i := range allAccounts {
acc := &allAccounts[i]
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType == "google_one" {
accounts = append(accounts, acc)
}
}
} else {
for _, id := range req.AccountIDs {
acc, err := h.adminService.GetAccount(ctx, id)
if err != nil {
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
for _, acc := range fetched {
if acc == nil {
continue
}
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
accounts = append(accounts, *acc)
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
continue
}
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
continue
}
accounts = append(accounts, acc)
}
}
total := len(accounts)
success := 0
failed := 0
errors := []gin.H{}
const maxConcurrency = 10
g, gctx := errgroup.WithContext(ctx)
g.SetLimit(maxConcurrency)
var mu sync.Mutex
results := gin.H{
"total": len(accounts),
"success": 0,
"failed": 0,
"errors": []gin.H{},
}
for _, account := range accounts {
accessToken, ok := account.Credentials["access_token"].(string)
if !ok || accessToken == "" {
failed++
errors = append(errors, gin.H{
"account_id": account.ID,
"error": "missing access_token",
acc := account // 闭包捕获
g.Go(func() error {
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
if err != nil {
mu.Lock()
results["failed"] = results["failed"].(int) + 1
results["errors"] = append(results["errors"].([]gin.H), gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
})
continue
}
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
mu.Lock()
if updateErr != nil {
results["failed"] = results["failed"].(int) + 1
results["errors"] = append(results["errors"].([]gin.H), gin.H{
"account_id": acc.ID,
"error": updateErr.Error(),
})
} else {
results["success"] = results["success"].(int) + 1
}
mu.Unlock()
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(ctx, accessToken, proxyURL)
if err != nil {
failed++
errors = append(errors, gin.H{
"account_id": account.ID,
"error": err.Error(),
})
continue
}
if account.Extra == nil {
account.Extra = make(map[string]any)
}
if storageInfo != nil {
account.Extra["drive_storage_limit"] = storageInfo.Limit
account.Extra["drive_storage_usage"] = storageInfo.Usage
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
}
account.Credentials["tier_id"] = tierID
_, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
Credentials: account.Credentials,
Extra: account.Extra,
return nil
})
if updateErr != nil {
failed++
errors = append(errors, gin.H{
"account_id": account.ID,
"error": updateErr.Error(),
})
continue
}
success++
}
response.Success(c, gin.H{
"total": total,
"success": success,
"failed": failed,
"errors": errors,
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, results)
}