test(gemini): 添加 Drive API 和 OAuth 服务单元测试
- 新增 drive_client_test.go:Drive API 客户端单元测试 - 新增 gemini_oauth_service_test.go:OAuth 服务单元测试 - 重构 account_handler.go:改进 RefreshTier API 实现 - 优化 drive_client.go:增强错误处理和重试逻辑 - 完善 repository 和 service 层:支持批量 tier 刷新 - 更新迁移文件编号:017 -> 024(避免冲突)
This commit is contained in:
@@ -3,7 +3,7 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"sync"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OAuthHandler handles OAuth-related operations for accounts
|
// OAuthHandler handles OAuth-related operations for accounts
|
||||||
@@ -1000,47 +1001,33 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
ctx := c.Request.Context()
|
||||||
|
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.NotFound(c, "Account not found")
|
response.NotFound(c, "Account not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Credentials == nil || account.Credentials["oauth_type"] != "google_one" {
|
if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
|
||||||
response.BadRequest(c, "Account is not a google_one OAuth account")
|
response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, ok := account.Credentials["access_token"].(string)
|
oauthType, _ := account.Credentials["oauth_type"].(string)
|
||||||
if !ok || accessToken == "" {
|
if oauthType != "google_one" {
|
||||||
response.BadRequest(c, "Missing access_token in credentials")
|
response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyURL string
|
tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL = account.Proxy.URL()
|
|
||||||
}
|
|
||||||
|
|
||||||
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(c.Request.Context(), accessToken, proxyURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Extra == nil {
|
_, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
|
||||||
account.Extra = make(map[string]any)
|
Credentials: creds,
|
||||||
}
|
Extra: extra,
|
||||||
if storageInfo != nil {
|
|
||||||
account.Extra["drive_storage_limit"] = storageInfo.Limit
|
|
||||||
account.Extra["drive_storage_usage"] = storageInfo.Usage
|
|
||||||
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
account.Credentials["tier_id"] = tierID
|
|
||||||
|
|
||||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
|
||||||
Credentials: account.Credentials,
|
|
||||||
Extra: account.Extra,
|
|
||||||
})
|
})
|
||||||
if updateErr != nil {
|
if updateErr != nil {
|
||||||
response.ErrorFrom(c, updateErr)
|
response.ErrorFrom(c, updateErr)
|
||||||
@@ -1049,9 +1036,10 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"tier_id": tierID,
|
"tier_id": tierID,
|
||||||
"drive_storage_limit": account.Extra["drive_storage_limit"],
|
"storage_info": extra,
|
||||||
"drive_storage_usage": account.Extra["drive_storage_usage"],
|
"drive_storage_limit": extra["drive_storage_limit"],
|
||||||
"updated_at": account.Extra["drive_tier_updated_at"],
|
"drive_storage_usage": extra["drive_storage_usage"],
|
||||||
|
"updated_at": extra["drive_tier_updated_at"],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1069,7 +1057,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
var accounts []service.Account
|
accounts := make([]*service.Account, 0)
|
||||||
|
|
||||||
if len(req.AccountIDs) == 0 {
|
if len(req.AccountIDs) == 0 {
|
||||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||||
@@ -1077,84 +1065,87 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
|||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, acc := range allAccounts {
|
for i := range allAccounts {
|
||||||
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
|
acc := &allAccounts[i]
|
||||||
|
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||||
|
if oauthType == "google_one" {
|
||||||
accounts = append(accounts, acc)
|
accounts = append(accounts, acc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, id := range req.AccountIDs {
|
fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
|
||||||
acc, err := h.adminService.GetAccount(ctx, id)
|
if err != nil {
|
||||||
if err != nil {
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, acc := range fetched {
|
||||||
|
if acc == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" {
|
if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
|
||||||
accounts = append(accounts, *acc)
|
continue
|
||||||
}
|
}
|
||||||
|
oauthType, _ := acc.Credentials["oauth_type"].(string)
|
||||||
|
if oauthType != "google_one" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
accounts = append(accounts, acc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
total := len(accounts)
|
const maxConcurrency = 10
|
||||||
success := 0
|
g, gctx := errgroup.WithContext(ctx)
|
||||||
failed := 0
|
g.SetLimit(maxConcurrency)
|
||||||
errors := []gin.H{}
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
results := gin.H{
|
||||||
|
"total": len(accounts),
|
||||||
|
"success": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"errors": []gin.H{},
|
||||||
|
}
|
||||||
|
|
||||||
for _, account := range accounts {
|
for _, account := range accounts {
|
||||||
accessToken, ok := account.Credentials["access_token"].(string)
|
acc := account // 闭包捕获
|
||||||
if !ok || accessToken == "" {
|
g.Go(func() error {
|
||||||
failed++
|
_, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
|
||||||
errors = append(errors, gin.H{
|
if err != nil {
|
||||||
"account_id": account.ID,
|
mu.Lock()
|
||||||
"error": "missing access_token",
|
results["failed"] = results["failed"].(int) + 1
|
||||||
|
results["errors"] = append(results["errors"].([]gin.H), gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
|
||||||
|
Credentials: creds,
|
||||||
|
Extra: extra,
|
||||||
})
|
})
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var proxyURL string
|
mu.Lock()
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if updateErr != nil {
|
||||||
proxyURL = account.Proxy.URL()
|
results["failed"] = results["failed"].(int) + 1
|
||||||
}
|
results["errors"] = append(results["errors"].([]gin.H), gin.H{
|
||||||
|
"account_id": acc.ID,
|
||||||
|
"error": updateErr.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
results["success"] = results["success"].(int) + 1
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(ctx, accessToken, proxyURL)
|
return nil
|
||||||
if err != nil {
|
|
||||||
failed++
|
|
||||||
errors = append(errors, gin.H{
|
|
||||||
"account_id": account.ID,
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.Extra == nil {
|
|
||||||
account.Extra = make(map[string]any)
|
|
||||||
}
|
|
||||||
if storageInfo != nil {
|
|
||||||
account.Extra["drive_storage_limit"] = storageInfo.Limit
|
|
||||||
account.Extra["drive_storage_usage"] = storageInfo.Usage
|
|
||||||
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
account.Credentials["tier_id"] = tierID
|
|
||||||
|
|
||||||
_, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
|
|
||||||
Credentials: account.Credentials,
|
|
||||||
Extra: account.Extra,
|
|
||||||
})
|
})
|
||||||
if updateErr != nil {
|
|
||||||
failed++
|
|
||||||
errors = append(errors, gin.H{
|
|
||||||
"account_id": account.ID,
|
|
||||||
"error": updateErr.Error(),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
success++
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
if err := g.Wait(); err != nil {
|
||||||
"total": total,
|
response.ErrorFrom(c, err)
|
||||||
"success": success,
|
return
|
||||||
"failed": failed,
|
}
|
||||||
"errors": errors,
|
|
||||||
})
|
response.Success(c, results)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
@@ -49,13 +51,38 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
|
|||||||
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retry logic with exponential backoff for rate limits
|
sleepWithContext := func(d time.Duration) error {
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
maxRetries := 3
|
maxRetries := 3
|
||||||
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||||
|
}
|
||||||
|
|
||||||
resp, err = client.Do(req)
|
resp, err = client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
// Network error retry
|
||||||
|
if attempt < maxRetries-1 {
|
||||||
|
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||||
|
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||||
|
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success
|
// Success
|
||||||
@@ -63,18 +90,34 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rate limit - retry with exponential backoff
|
// Retry 429, 500, 502, 503 with exponential backoff + jitter
|
||||||
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetries-1 {
|
if (resp.StatusCode == http.StatusTooManyRequests ||
|
||||||
|
resp.StatusCode == http.StatusInternalServerError ||
|
||||||
|
resp.StatusCode == http.StatusBadGateway ||
|
||||||
|
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
backoff := time.Duration(1<<uint(attempt)) * time.Second // 1s, 2s, 4s
|
backoff := time.Duration(1<<uint(attempt)) * time.Second
|
||||||
time.Sleep(backoff)
|
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
|
||||||
|
if err := sleepWithContext(backoff + jitter); err != nil {
|
||||||
|
return nil, fmt.Errorf("request cancelled: %w", err)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other errors - return immediately
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil {
|
||||||
|
return nil, fmt.Errorf("request failed: no response received")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
return nil, fmt.Errorf("drive API error (status %d): %s", resp.StatusCode, string(body))
|
// 记录完整错误
|
||||||
|
fmt.Printf("[DriveClient] API error (status %d): %s\n", resp.StatusCode, string(body))
|
||||||
|
// 只返回通用错误
|
||||||
|
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
@@ -94,10 +137,14 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
|
|||||||
// Parse limit and usage (handle both string and number formats)
|
// Parse limit and usage (handle both string and number formats)
|
||||||
var limit, usage int64
|
var limit, usage int64
|
||||||
if result.StorageQuota.Limit != "" {
|
if result.StorageQuota.Limit != "" {
|
||||||
_, _ = fmt.Sscanf(result.StorageQuota.Limit, "%d", &limit)
|
if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
|
||||||
|
limit = val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if result.StorageQuota.Usage != "" {
|
if result.StorageQuota.Usage != "" {
|
||||||
_, _ = fmt.Sscanf(result.StorageQuota.Usage, "%d", &usage)
|
if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
|
||||||
|
usage = val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DriveStorageInfo{
|
return &DriveStorageInfo{
|
||||||
|
|||||||
19
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
19
backend/internal/pkg/geminicli/drive_client_test.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package geminicli
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDriveStorageInfo(t *testing.T) {
|
||||||
|
// 测试 DriveStorageInfo 结构体
|
||||||
|
info := &DriveStorageInfo{
|
||||||
|
Limit: 100 * 1024 * 1024 * 1024, // 100GB
|
||||||
|
Usage: 50 * 1024 * 1024 * 1024, // 50GB
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Limit != 100*1024*1024*1024 {
|
||||||
|
t.Errorf("Expected limit 100GB, got %d", info.Limit)
|
||||||
|
}
|
||||||
|
if info.Usage != 50*1024*1024*1024 {
|
||||||
|
t.Errorf("Expected usage 50GB, got %d", info.Usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
|
|||||||
return &accounts[0], nil
|
return &accounts[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// De-duplicate while preserving order of first occurrence.
|
||||||
|
uniqueIDs := make([]int64, 0, len(ids))
|
||||||
|
seen := make(map[int64]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if id <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
uniqueIDs = append(uniqueIDs, id)
|
||||||
|
}
|
||||||
|
if len(uniqueIDs) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entAccounts, err := r.client.Account.
|
||||||
|
Query().
|
||||||
|
Where(dbaccount.IDIn(uniqueIDs...)).
|
||||||
|
WithProxy().
|
||||||
|
All(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(entAccounts) == 0 {
|
||||||
|
return []*service.Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accountIDs := make([]int64, 0, len(entAccounts))
|
||||||
|
entByID := make(map[int64]*dbent.Account, len(entAccounts))
|
||||||
|
for _, acc := range entAccounts {
|
||||||
|
entByID[acc.ID] = acc
|
||||||
|
accountIDs = append(accountIDs, acc.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
outByID := make(map[int64]*service.Account, len(entAccounts))
|
||||||
|
for _, entAcc := range entAccounts {
|
||||||
|
out := accountEntityToService(entAcc)
|
||||||
|
if out == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer the preloaded proxy edge when available.
|
||||||
|
if entAcc.Edges.Proxy != nil {
|
||||||
|
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
if groups, ok := groupsByAccount[entAcc.ID]; ok {
|
||||||
|
out.Groups = groups
|
||||||
|
}
|
||||||
|
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
|
||||||
|
out.GroupIDs = groupIDs
|
||||||
|
}
|
||||||
|
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
|
||||||
|
out.AccountGroups = ags
|
||||||
|
}
|
||||||
|
outByID[entAcc.ID] = out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve input order (first occurrence), and ignore missing IDs.
|
||||||
|
out := make([]*service.Account, 0, len(uniqueIDs))
|
||||||
|
for _, id := range uniqueIDs {
|
||||||
|
if _, ok := entByID[id]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if acc, ok := outByID[id]; ok && acc != nil {
|
||||||
|
out = append(out, acc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ExistsByID 检查指定 ID 的账号是否存在。
|
// ExistsByID 检查指定 ID 的账号是否存在。
|
||||||
// 相比 GetByID,此方法性能更优,因为:
|
// 相比 GetByID,此方法性能更优,因为:
|
||||||
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
|
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ var (
|
|||||||
type AccountRepository interface {
|
type AccountRepository interface {
|
||||||
Create(ctx context.Context, account *Account) error
|
Create(ctx context.Context, account *Account) error
|
||||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||||
|
// GetByIDs fetches accounts by IDs in a single query.
|
||||||
|
// It should return all accounts found (missing IDs are ignored).
|
||||||
|
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
||||||
ExistsByID(ctx context.Context, id int64) (bool, error)
|
ExistsByID(ctx context.Context, id int64) (bool, error)
|
||||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ type AdminService interface {
|
|||||||
// Account management
|
// Account management
|
||||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||||
|
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
||||||
DeleteAccount(ctx context.Context, id int64) error
|
DeleteAccount(ctx context.Context, id int64) error
|
||||||
@@ -611,6 +612,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
|
|||||||
return s.accountRepo.GetByID(ctx, id)
|
return s.accountRepo.GetByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*Account{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
|
|||||||
@@ -26,6 +26,17 @@ const (
|
|||||||
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
GB = 1024 * 1024 * 1024
|
||||||
|
TB = 1024 * GB
|
||||||
|
|
||||||
|
StorageTierUnlimited = 100 * TB // 100TB
|
||||||
|
StorageTierAIPremium = 2 * TB // 2TB
|
||||||
|
StorageTierStandard = 200 * GB // 200GB
|
||||||
|
StorageTierBasic = 100 * GB // 100GB
|
||||||
|
StorageTierFree = 15 * GB // 15GB
|
||||||
|
)
|
||||||
|
|
||||||
type GeminiOAuthService struct {
|
type GeminiOAuthService struct {
|
||||||
sessionStore *geminicli.SessionStore
|
sessionStore *geminicli.SessionStore
|
||||||
proxyRepo ProxyRepository
|
proxyRepo ProxyRepository
|
||||||
@@ -222,31 +233,21 @@ func inferGoogleOneTier(storageBytes int64) string {
|
|||||||
return TierGoogleOneUnknown
|
return TierGoogleOneUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unlimited storage (G Suite legacy)
|
if storageBytes > StorageTierUnlimited {
|
||||||
if storageBytes > 100*1024*1024*1024*1024 { // > 100TB
|
|
||||||
return TierGoogleOneUnlimited
|
return TierGoogleOneUnlimited
|
||||||
}
|
}
|
||||||
|
if storageBytes >= StorageTierAIPremium {
|
||||||
// AI Premium (2TB+)
|
|
||||||
if storageBytes >= 2*1024*1024*1024*1024 { // >= 2TB
|
|
||||||
return TierAIPremium
|
return TierAIPremium
|
||||||
}
|
}
|
||||||
|
if storageBytes >= StorageTierStandard {
|
||||||
// Google One Standard (200GB)
|
|
||||||
if storageBytes >= 200*1024*1024*1024 { // >= 200GB
|
|
||||||
return TierGoogleOneStandard
|
return TierGoogleOneStandard
|
||||||
}
|
}
|
||||||
|
if storageBytes >= StorageTierBasic {
|
||||||
// Google One Basic (100GB)
|
|
||||||
if storageBytes >= 100*1024*1024*1024 { // >= 100GB
|
|
||||||
return TierGoogleOneBasic
|
return TierGoogleOneBasic
|
||||||
}
|
}
|
||||||
|
if storageBytes >= StorageTierFree {
|
||||||
// Free (15GB)
|
|
||||||
if storageBytes >= 15*1024*1024*1024 { // >= 15GB
|
|
||||||
return TierFree
|
return TierFree
|
||||||
}
|
}
|
||||||
|
|
||||||
return TierGoogleOneUnknown
|
return TierGoogleOneUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,6 +271,60 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
|
|||||||
return tierID, storageInfo, nil
|
return tierID, storageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
|
||||||
|
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||||
|
ctx context.Context,
|
||||||
|
account *Account,
|
||||||
|
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
|
||||||
|
if account == nil {
|
||||||
|
return "", nil, nil, fmt.Errorf("account is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证账号类型
|
||||||
|
oauthType, ok := account.Credentials["oauth_type"].(string)
|
||||||
|
if !ok || oauthType != "google_one" {
|
||||||
|
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 access_token
|
||||||
|
accessToken, ok := account.Credentials["access_token"].(string)
|
||||||
|
if !ok || accessToken == "" {
|
||||||
|
return "", nil, nil, fmt.Errorf("missing access_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 proxy URL
|
||||||
|
var proxyURL string
|
||||||
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
|
proxyURL = account.Proxy.URL()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 Drive API
|
||||||
|
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 extra 数据(保留原有 extra 字段)
|
||||||
|
extra = make(map[string]any)
|
||||||
|
for k, v := range account.Extra {
|
||||||
|
extra[k] = v
|
||||||
|
}
|
||||||
|
if storageInfo != nil {
|
||||||
|
extra["drive_storage_limit"] = storageInfo.Limit
|
||||||
|
extra["drive_storage_usage"] = storageInfo.Usage
|
||||||
|
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 credentials 数据
|
||||||
|
credentials = make(map[string]any)
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
credentials[k] = v
|
||||||
|
}
|
||||||
|
credentials["tier_id"] = tierID
|
||||||
|
|
||||||
|
return tierID, extra, credentials, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||||
session, ok := s.sessionStore.Get(input.SessionID)
|
session, ok := s.sessionStore.Get(input.SessionID)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
52
backend/internal/service/gemini_oauth_service_test.go
Normal file
52
backend/internal/service/gemini_oauth_service_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestInferGoogleOneTier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
storageBytes int64
|
||||||
|
expectedTier string
|
||||||
|
}{
|
||||||
|
{"Negative storage", -1, TierGoogleOneUnknown},
|
||||||
|
{"Zero storage", 0, TierGoogleOneUnknown},
|
||||||
|
|
||||||
|
// Free tier boundary (15GB)
|
||||||
|
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
|
||||||
|
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
|
||||||
|
{"Free tier (15GB)", StorageTierFree, TierFree},
|
||||||
|
|
||||||
|
// Basic tier boundary (100GB)
|
||||||
|
{"Between free and basic", 50 * GB, TierFree},
|
||||||
|
{"Just below basic tier", StorageTierBasic - 1, TierFree},
|
||||||
|
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
|
||||||
|
|
||||||
|
// Standard tier boundary (200GB)
|
||||||
|
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
|
||||||
|
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
|
||||||
|
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
|
||||||
|
|
||||||
|
// AI Premium tier boundary (2TB)
|
||||||
|
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
|
||||||
|
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
|
||||||
|
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
|
||||||
|
|
||||||
|
// Unlimited tier boundary (> 100TB)
|
||||||
|
{"Between premium and unlimited", 50 * TB, TierAIPremium},
|
||||||
|
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
|
||||||
|
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
|
||||||
|
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
|
||||||
|
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := inferGoogleOneTier(tt.storageBytes)
|
||||||
|
if result != tt.expectedTier {
|
||||||
|
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
|
||||||
|
tt.storageBytes, result, tt.expectedTier)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -26,5 +26,5 @@ UPDATE accounts
|
|||||||
SET credentials = credentials - 'tier_id'
|
SET credentials = credentials - 'tier_id'
|
||||||
WHERE platform = 'gemini'
|
WHERE platform = 'gemini'
|
||||||
AND type = 'oauth'
|
AND type = 'oauth'
|
||||||
AND credentials->>'oauth_type' = 'code_assist';
|
AND credentials ? 'tier_id';
|
||||||
-- +goose StatementEnd
|
-- +goose StatementEnd
|
||||||
Reference in New Issue
Block a user