Merge branch 'main' into feature/antigravity_auth
This commit is contained in:
@@ -54,15 +54,23 @@ type UsageLogRepository interface {
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
}
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
type usageCache struct {
|
||||
data *UsageInfo
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost)
|
||||
type windowStatsCache struct {
|
||||
stats *WindowStats
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
usageCacheMap = sync.Map{}
|
||||
cacheTTL = 10 * time.Minute
|
||||
apiCacheMap = sync.Map{} // 缓存 API 响应
|
||||
windowStatsCacheMap = sync.Map{} // 缓存窗口统计
|
||||
apiCacheTTL = 10 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
@@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog
|
||||
}
|
||||
|
||||
// GetUsage 获取账号使用量
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
@@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
if account.CanGetUsage() {
|
||||
// 检查缓存
|
||||
if cached, ok := usageCacheMap.Load(accountID); ok {
|
||||
cache, ok := cached.(*usageCache)
|
||||
if !ok {
|
||||
usageCacheMap.Delete(accountID)
|
||||
} else if time.Since(cache.timestamp) < cacheTTL {
|
||||
return cache.data, nil
|
||||
var apiResp *ClaudeUsageResponse
|
||||
|
||||
// 1. 检查 API 缓存(10 分钟)
|
||||
if cached, ok := apiCacheMap.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
}
|
||||
}
|
||||
|
||||
// 从API获取数据
|
||||
usage, err := s.fetchOAuthUsage(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 2. 如果没有缓存,从 API 获取
|
||||
if apiResp == nil {
|
||||
apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 缓存 API 响应
|
||||
apiCacheMap.Store(accountID, &apiUsageCache{
|
||||
response: apiResp,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// 添加5h窗口统计数据
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
// 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
|
||||
now := time.Now()
|
||||
usage := s.buildUsageInfo(apiResp, &now)
|
||||
|
||||
// 缓存结果
|
||||
usageCacheMap.Store(accountID, &usageCache{
|
||||
data: usage,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
// addWindowStats 为usage数据添加窗口期统计
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
if usage.FiveHour == nil {
|
||||
// 修复:即使 FiveHour 为 nil,也要尝试获取统计数据
|
||||
// 因为 SevenDay/SevenDaySonnet 可能需要
|
||||
if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用session_window_start作为统计起始时间
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
// 如果没有窗口信息,使用5小时前作为默认
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
// 检查窗口统计缓存(1 分钟)
|
||||
var windowStats *WindowStats
|
||||
if cached, ok := windowStatsCacheMap.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
|
||||
windowStats = cache.stats
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
windowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
}
|
||||
|
||||
// 缓存窗口统计(1 分钟)
|
||||
windowStatsCacheMap.Store(account.ID, &windowStatsCache{
|
||||
stats: windowStats,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
usage.FiveHour.WindowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
// 为 FiveHour 添加 WindowStats(5h 窗口统计)
|
||||
if usage.FiveHour != nil {
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
@@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.buildUsageInfo(usageResp, &now), nil
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
@@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
// 5小时窗口
|
||||
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
}
|
||||
if resp.FiveHour.ResetsAt != "" {
|
||||
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
ResetsAt: &fiveHourReset,
|
||||
RemainingSeconds: int(time.Until(fiveHourReset).Seconds()),
|
||||
}
|
||||
info.FiveHour.ResetsAt = &fiveHourReset
|
||||
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
|
||||
} else {
|
||||
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
|
||||
// 即使解析失败也返回utilization
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(input.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil {
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
)
|
||||
|
||||
@@ -32,14 +33,16 @@ type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
|
||||
@@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:忽略 groupID,查询所有可用账号
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
@@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
// 简易模式:忽略分组限制,查询所有可用账号
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
|
||||
@@ -164,6 +164,14 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConcurrency 更新用户并发数(管理员功能)
|
||||
func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concurrency int) error {
|
||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
|
||||
return fmt.Errorf("update concurrency: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态(管理员功能)
|
||||
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
|
||||
Reference in New Issue
Block a user