test(gemini): 添加 Drive API 和 OAuth 服务单元测试

- 新增 drive_client_test.go:Drive API 客户端单元测试
- 新增 gemini_oauth_service_test.go:OAuth 服务单元测试
- 重构 account_handler.go:改进 RefreshTier API 实现
- 优化 drive_client.go:增强错误处理和重试逻辑
- 完善 repository 和 service 层:支持批量 tier 刷新
- 更新迁移文件编号:017 -> 024(避免冲突)
This commit is contained in:
IanShaw027
2026-01-01 15:07:16 +08:00
parent 34bbfb5dd2
commit 48764e15a5
9 changed files with 383 additions and 118 deletions

View File

@@ -3,7 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings" "strings"
"time" "sync"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
) )
// OAuthHandler handles OAuth-related operations for accounts // OAuthHandler handles OAuth-related operations for accounts
@@ -1000,47 +1001,33 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
return return
} }
account, err := h.adminService.GetAccount(c.Request.Context(), accountID) ctx := c.Request.Context()
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil { if err != nil {
response.NotFound(c, "Account not found") response.NotFound(c, "Account not found")
return return
} }
if account.Credentials == nil || account.Credentials["oauth_type"] != "google_one" { if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth {
response.BadRequest(c, "Account is not a google_one OAuth account") response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh")
return return
} }
accessToken, ok := account.Credentials["access_token"].(string) oauthType, _ := account.Credentials["oauth_type"].(string)
if !ok || accessToken == "" { if oauthType != "google_one" {
response.BadRequest(c, "Missing access_token in credentials") response.BadRequest(c, "Only google_one OAuth accounts support tier refresh")
return return
} }
var proxyURL string tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account)
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(c.Request.Context(), accessToken, proxyURL)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if account.Extra == nil { _, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{
account.Extra = make(map[string]any) Credentials: creds,
} Extra: extra,
if storageInfo != nil {
account.Extra["drive_storage_limit"] = storageInfo.Limit
account.Extra["drive_storage_usage"] = storageInfo.Usage
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
}
account.Credentials["tier_id"] = tierID
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: account.Credentials,
Extra: account.Extra,
}) })
if updateErr != nil { if updateErr != nil {
response.ErrorFrom(c, updateErr) response.ErrorFrom(c, updateErr)
@@ -1049,9 +1036,10 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) {
response.Success(c, gin.H{ response.Success(c, gin.H{
"tier_id": tierID, "tier_id": tierID,
"drive_storage_limit": account.Extra["drive_storage_limit"], "storage_info": extra,
"drive_storage_usage": account.Extra["drive_storage_usage"], "drive_storage_limit": extra["drive_storage_limit"],
"updated_at": account.Extra["drive_tier_updated_at"], "drive_storage_usage": extra["drive_storage_usage"],
"updated_at": extra["drive_tier_updated_at"],
}) })
} }
@@ -1069,7 +1057,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
} }
ctx := c.Request.Context() ctx := c.Request.Context()
var accounts []service.Account accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 { if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
@@ -1077,84 +1065,87 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
for _, acc := range allAccounts { for i := range allAccounts {
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" { acc := &allAccounts[i]
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType == "google_one" {
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
} }
} else { } else {
for _, id := range req.AccountIDs { fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs)
acc, err := h.adminService.GetAccount(ctx, id) if err != nil {
if err != nil { response.ErrorFrom(c, err)
return
}
for _, acc := range fetched {
if acc == nil {
continue continue
} }
if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" { if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth {
accounts = append(accounts, *acc) continue
} }
oauthType, _ := acc.Credentials["oauth_type"].(string)
if oauthType != "google_one" {
continue
}
accounts = append(accounts, acc)
} }
} }
total := len(accounts) const maxConcurrency = 10
success := 0 g, gctx := errgroup.WithContext(ctx)
failed := 0 g.SetLimit(maxConcurrency)
errors := []gin.H{}
var mu sync.Mutex
results := gin.H{
"total": len(accounts),
"success": 0,
"failed": 0,
"errors": []gin.H{},
}
for _, account := range accounts { for _, account := range accounts {
accessToken, ok := account.Credentials["access_token"].(string) acc := account // 闭包捕获
if !ok || accessToken == "" { g.Go(func() error {
failed++ _, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc)
errors = append(errors, gin.H{ if err != nil {
"account_id": account.ID, mu.Lock()
"error": "missing access_token", results["failed"] = results["failed"].(int) + 1
results["errors"] = append(results["errors"].([]gin.H), gin.H{
"account_id": acc.ID,
"error": err.Error(),
})
mu.Unlock()
return nil
}
_, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{
Credentials: creds,
Extra: extra,
}) })
continue
}
var proxyURL string mu.Lock()
if account.ProxyID != nil && account.Proxy != nil { if updateErr != nil {
proxyURL = account.Proxy.URL() results["failed"] = results["failed"].(int) + 1
} results["errors"] = append(results["errors"].([]gin.H), gin.H{
"account_id": acc.ID,
"error": updateErr.Error(),
})
} else {
results["success"] = results["success"].(int) + 1
}
mu.Unlock()
tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(ctx, accessToken, proxyURL) return nil
if err != nil {
failed++
errors = append(errors, gin.H{
"account_id": account.ID,
"error": err.Error(),
})
continue
}
if account.Extra == nil {
account.Extra = make(map[string]any)
}
if storageInfo != nil {
account.Extra["drive_storage_limit"] = storageInfo.Limit
account.Extra["drive_storage_usage"] = storageInfo.Usage
account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339)
}
account.Credentials["tier_id"] = tierID
_, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{
Credentials: account.Credentials,
Extra: account.Extra,
}) })
if updateErr != nil {
failed++
errors = append(errors, gin.H{
"account_id": account.ID,
"error": updateErr.Error(),
})
continue
}
success++
} }
response.Success(c, gin.H{ if err := g.Wait(); err != nil {
"total": total, response.ErrorFrom(c, err)
"success": success, return
"failed": failed, }
"errors": errors,
}) response.Success(c, results)
} }

View File

@@ -5,7 +5,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
@@ -49,13 +51,38 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
return nil, fmt.Errorf("failed to create HTTP client: %w", err) return nil, fmt.Errorf("failed to create HTTP client: %w", err)
} }
// Retry logic with exponential backoff for rate limits sleepWithContext := func(d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// Retry logic with exponential backoff (+ jitter) for rate limits and transient failures
var resp *http.Response var resp *http.Response
maxRetries := 3 maxRetries := 3
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := 0; attempt < maxRetries; attempt++ {
if ctx.Err() != nil {
return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
}
resp, err = client.Do(req) resp, err = client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err) // Network error retry
if attempt < maxRetries-1 {
backoff := time.Duration(1<<uint(attempt)) * time.Second
jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue
}
return nil, fmt.Errorf("network error after %d attempts: %w", maxRetries, err)
} }
// Success // Success
@@ -63,18 +90,34 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
break break
} }
// Rate limit - retry with exponential backoff // Retry 429, 500, 502, 503 with exponential backoff + jitter
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetries-1 { if (resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusInternalServerError ||
resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable) && attempt < maxRetries-1 {
_ = resp.Body.Close() _ = resp.Body.Close()
backoff := time.Duration(1<<uint(attempt)) * time.Second // 1s, 2s, 4s backoff := time.Duration(1<<uint(attempt)) * time.Second
time.Sleep(backoff) jitter := time.Duration(rng.Intn(1000)) * time.Millisecond
if err := sleepWithContext(backoff + jitter); err != nil {
return nil, fmt.Errorf("request cancelled: %w", err)
}
continue continue
} }
// Other errors - return immediately break
}
if resp == nil {
return nil, fmt.Errorf("request failed: no response received")
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close() _ = resp.Body.Close()
return nil, fmt.Errorf("drive API error (status %d): %s", resp.StatusCode, string(body)) // 记录完整错误
fmt.Printf("[DriveClient] API error (status %d): %s\n", resp.StatusCode, string(body))
// 只返回通用错误
return nil, fmt.Errorf("drive API error: status %d", resp.StatusCode)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
@@ -94,10 +137,14 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL
// Parse limit and usage (handle both string and number formats) // Parse limit and usage (handle both string and number formats)
var limit, usage int64 var limit, usage int64
if result.StorageQuota.Limit != "" { if result.StorageQuota.Limit != "" {
_, _ = fmt.Sscanf(result.StorageQuota.Limit, "%d", &limit) if val, err := strconv.ParseInt(result.StorageQuota.Limit, 10, 64); err == nil {
limit = val
}
} }
if result.StorageQuota.Usage != "" { if result.StorageQuota.Usage != "" {
_, _ = fmt.Sscanf(result.StorageQuota.Usage, "%d", &usage) if val, err := strconv.ParseInt(result.StorageQuota.Usage, 10, 64); err == nil {
usage = val
}
} }
return &DriveStorageInfo{ return &DriveStorageInfo{

View File

@@ -0,0 +1,19 @@
package geminicli
import "testing"
func TestDriveStorageInfo(t *testing.T) {
// 测试 DriveStorageInfo 结构体
info := &DriveStorageInfo{
Limit: 100 * 1024 * 1024 * 1024, // 100GB
Usage: 50 * 1024 * 1024 * 1024, // 50GB
}
if info.Limit != 100*1024*1024*1024 {
t.Errorf("Expected limit 100GB, got %d", info.Limit)
}
if info.Usage != 50*1024*1024*1024 {
t.Errorf("Expected usage 50GB, got %d", info.Usage)
}
}

View File

@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return &accounts[0], nil return &accounts[0], nil
} }
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
if len(ids) == 0 {
return []*service.Account{}, nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return []*service.Account{}, nil
}
entAccounts, err := r.client.Account.
Query().
Where(dbaccount.IDIn(uniqueIDs...)).
WithProxy().
All(ctx)
if err != nil {
return nil, err
}
if len(entAccounts) == 0 {
return []*service.Account{}, nil
}
accountIDs := make([]int64, 0, len(entAccounts))
entByID := make(map[int64]*dbent.Account, len(entAccounts))
for _, acc := range entAccounts {
entByID[acc.ID] = acc
accountIDs = append(accountIDs, acc.ID)
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outByID := make(map[int64]*service.Account, len(entAccounts))
for _, entAcc := range entAccounts {
out := accountEntityToService(entAcc)
if out == nil {
continue
}
// Prefer the preloaded proxy edge when available.
if entAcc.Edges.Proxy != nil {
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
}
if groups, ok := groupsByAccount[entAcc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
outByID[entAcc.ID] = out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out := make([]*service.Account, 0, len(uniqueIDs))
for _, id := range uniqueIDs {
if _, ok := entByID[id]; !ok {
continue
}
if acc, ok := outByID[id]; ok && acc != nil {
out = append(out, acc)
}
}
return out, nil
}
// ExistsByID 检查指定 ID 的账号是否存在。 // ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID此方法性能更优因为 // 相比 GetByID此方法性能更优因为
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值 // - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值

View File

@@ -17,6 +17,9 @@ var (
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查 // ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error) ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.

View File

@@ -35,6 +35,7 @@ type AdminService interface {
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
@@ -611,6 +612,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
if len(ids) == 0 {
return []*Account{}, nil
}
accounts, err := s.accountRepo.GetByIDs(ctx, ids)
if err != nil {
return nil, fmt.Errorf("failed to get accounts by IDs: %w", err)
}
return accounts, nil
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{ account := &Account{
Name: input.Name, Name: input.Name,

View File

@@ -26,6 +26,17 @@ const (
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED" TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
) )
const (
GB = 1024 * 1024 * 1024
TB = 1024 * GB
StorageTierUnlimited = 100 * TB // 100TB
StorageTierAIPremium = 2 * TB // 2TB
StorageTierStandard = 200 * GB // 200GB
StorageTierBasic = 100 * GB // 100GB
StorageTierFree = 15 * GB // 15GB
)
type GeminiOAuthService struct { type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository proxyRepo ProxyRepository
@@ -222,31 +233,21 @@ func inferGoogleOneTier(storageBytes int64) string {
return TierGoogleOneUnknown return TierGoogleOneUnknown
} }
// Unlimited storage (G Suite legacy) if storageBytes > StorageTierUnlimited {
if storageBytes > 100*1024*1024*1024*1024 { // > 100TB
return TierGoogleOneUnlimited return TierGoogleOneUnlimited
} }
if storageBytes >= StorageTierAIPremium {
// AI Premium (2TB+)
if storageBytes >= 2*1024*1024*1024*1024 { // >= 2TB
return TierAIPremium return TierAIPremium
} }
if storageBytes >= StorageTierStandard {
// Google One Standard (200GB)
if storageBytes >= 200*1024*1024*1024 { // >= 200GB
return TierGoogleOneStandard return TierGoogleOneStandard
} }
if storageBytes >= StorageTierBasic {
// Google One Basic (100GB)
if storageBytes >= 100*1024*1024*1024 { // >= 100GB
return TierGoogleOneBasic return TierGoogleOneBasic
} }
if storageBytes >= StorageTierFree {
// Free (15GB)
if storageBytes >= 15*1024*1024*1024 { // >= 15GB
return TierFree return TierFree
} }
return TierGoogleOneUnknown return TierGoogleOneUnknown
} }
@@ -270,6 +271,60 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
return tierID, storageInfo, nil return tierID, storageInfo, nil
} }
// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
ctx context.Context,
account *Account,
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
if account == nil {
return "", nil, nil, fmt.Errorf("account is nil")
}
// 验证账号类型
oauthType, ok := account.Credentials["oauth_type"].(string)
if !ok || oauthType != "google_one" {
return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
}
// 获取 access_token
accessToken, ok := account.Credentials["access_token"].(string)
if !ok || accessToken == "" {
return "", nil, nil, fmt.Errorf("missing access_token")
}
// 获取 proxy URL
var proxyURL string
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 调用 Drive API
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL)
if err != nil {
return "", nil, nil, err
}
// 构建 extra 数据(保留原有 extra 字段)
extra = make(map[string]any)
for k, v := range account.Extra {
extra[k] = v
}
if storageInfo != nil {
extra["drive_storage_limit"] = storageInfo.Limit
extra["drive_storage_usage"] = storageInfo.Usage
extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
}
// 构建 credentials 数据
credentials = make(map[string]any)
for k, v := range account.Credentials {
credentials[k] = v
}
credentials["tier_id"] = tierID
return tierID, extra, credentials, nil
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID) session, ok := s.sessionStore.Get(input.SessionID)
if !ok { if !ok {

View File

@@ -0,0 +1,52 @@
package service
import "testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
}
})
}
}

View File

@@ -26,5 +26,5 @@ UPDATE accounts
SET credentials = credentials - 'tier_id' SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini' WHERE platform = 'gemini'
AND type = 'oauth' AND type = 'oauth'
AND credentials->>'oauth_type' = 'code_assist'; AND credentials ? 'tier_id';
-- +goose StatementEnd -- +goose StatementEnd