diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index c47e66df..1c0ef8e6 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -844,6 +845,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage +// +// Two modes: +// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage. +// - unrestricted: No key-level limits. Returns subscription or wallet balance info. func (h *GatewayHandler) Usage(c *gin.Context) { apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -857,54 +862,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } + ctx := c.Request.Context() + + // 解析可选的日期范围参数(用于 model_stats 查询) + startTime, endTime := h.parseUsageDateRange(c) + // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 - var usageData gin.H + usageData := h.buildUsageData(ctx, apiKey.ID) + + // Best-effort: 获取模型统计 + var modelStats any if h.usageService != nil { - dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) - if err == nil && dashStats != nil { - usageData = gin.H{ - "today": gin.H{ - "requests": dashStats.TodayRequests, - "input_tokens": dashStats.TodayInputTokens, - "output_tokens": dashStats.TodayOutputTokens, - "cache_creation_tokens": dashStats.TodayCacheCreationTokens, - "cache_read_tokens": dashStats.TodayCacheReadTokens, - "total_tokens": dashStats.TodayTokens, - "cost": dashStats.TodayCost, - "actual_cost": dashStats.TodayActualCost, - }, - "total": gin.H{ - "requests": dashStats.TotalRequests, - "input_tokens": dashStats.TotalInputTokens, - "output_tokens": dashStats.TotalOutputTokens, - "cache_creation_tokens": dashStats.TotalCacheCreationTokens, - "cache_read_tokens": dashStats.TotalCacheReadTokens, - "total_tokens": dashStats.TotalTokens, - "cost": dashStats.TotalCost, - "actual_cost": dashStats.TotalActualCost, - }, - "average_duration_ms": dashStats.AverageDurationMs, - "rpm": dashStats.Rpm, - "tpm": dashStats.Tpm, + if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 { + modelStats = stats + } + } + + // 判断模式: key 有总额度或速率限制 → quota_limited,否则 → unrestricted + isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() + + if isQuotaLimited { + h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) + return + } + + h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) +} + +// parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围 +func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) { + now := timezone.Now() + endTime := now + startTime := now.AddDate(0, 0, -30) + + if s := c.Query("start_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + startTime = t + } + } + if s := c.Query("end_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + endTime = t.Add(24*time.Hour - time.Second) // end of day + } + } + return startTime, endTime +} + +// buildUsageData 构建 today/total 用量摘要 +func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H { + if h.usageService == nil { + return nil + } + dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID) + if err != nil || dashStats == nil { + return nil + } + return gin.H{ + "today": gin.H{ + "requests": dashStats.TodayRequests, + "input_tokens": dashStats.TodayInputTokens, + "output_tokens": dashStats.TodayOutputTokens, + "cache_creation_tokens": dashStats.TodayCacheCreationTokens, + "cache_read_tokens": dashStats.TodayCacheReadTokens, + "total_tokens": dashStats.TodayTokens, + "cost": dashStats.TodayCost, + "actual_cost": dashStats.TodayActualCost, + }, + "total": gin.H{ + "requests": dashStats.TotalRequests, + "input_tokens": dashStats.TotalInputTokens, + "output_tokens": dashStats.TotalOutputTokens, + "cache_creation_tokens": dashStats.TotalCacheCreationTokens, + "cache_read_tokens": dashStats.TotalCacheReadTokens, + "total_tokens": dashStats.TotalTokens, + "cost": dashStats.TotalCost, + "actual_cost": dashStats.TotalActualCost, + }, + "average_duration_ms": dashStats.AverageDurationMs, + "rpm": dashStats.Rpm, + "tpm": dashStats.Tpm, + } +} + +// usageQuotaLimited 处理 quota_limited 模式的响应 +func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { + resp := gin.H{ + "mode": "quota_limited", + "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, + "status": apiKey.Status, + } + + // 总额度信息 + if apiKey.Quota > 0 { + remaining := apiKey.GetQuotaRemaining() + resp["quota"] = gin.H{ + "limit": apiKey.Quota, + "used": apiKey.QuotaUsed, + "remaining": remaining, + "unit": "USD", + } + resp["remaining"] = remaining + resp["unit"] = "USD" + } + + // 速率限制信息(从 DB 获取实时用量) + if apiKey.HasRateLimits() && h.apiKeyService != nil { + rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID) + if err == nil && rateLimitData != nil { + var rateLimits []gin.H + if apiKey.RateLimit5h > 0 { + used := rateLimitData.Usage5h + rateLimits = append(rateLimits, gin.H{ + "window": "5h", + "limit": apiKey.RateLimit5h, + "used": used, + "remaining": max(0, apiKey.RateLimit5h-used), + "window_start": rateLimitData.Window5hStart, + }) + } + if apiKey.RateLimit1d > 0 { + used := rateLimitData.Usage1d + rateLimits = append(rateLimits, gin.H{ + "window": "1d", + "limit": apiKey.RateLimit1d, + "used": used, + "remaining": max(0, apiKey.RateLimit1d-used), + "window_start": rateLimitData.Window1dStart, + }) + } + if apiKey.RateLimit7d > 0 { + used := rateLimitData.Usage7d + rateLimits = append(rateLimits, gin.H{ + "window": "7d", + "limit": apiKey.RateLimit7d, + "used": used, + "remaining": max(0, apiKey.RateLimit7d-used), + "window_start": rateLimitData.Window7dStart, + }) + } + if len(rateLimits) > 0 { + resp["rate_limits"] = rateLimits } } } - // 订阅模式:返回订阅限额信息 + 用量统计 + // 过期时间 + if apiKey.ExpiresAt != nil { + resp["expires_at"] = apiKey.ExpiresAt + resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry() + } + + if usageData != nil { + resp["usage"] = usageData + } + if modelStats != nil { + resp["model_stats"] = modelStats + } + + c.JSON(http.StatusOK, resp) +} + +// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容) +func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { + // 订阅模式 if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { - subscription, ok := middleware2.GetSubscriptionFromContext(c) - if !ok { - h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription") - return + resp := gin.H{ + "mode": "unrestricted", + "isValid": true, + "planName": apiKey.Group.Name, + "unit": "USD", } - remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) - resp := gin.H{ - "isValid": true, - "planName": apiKey.Group.Name, - "remaining": remaining, - "unit": "USD", - "subscription": gin.H{ + // 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查) + subscription, ok := middleware2.GetSubscriptionFromContext(c) + if ok { + remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) + resp["remaining"] = remaining + resp["subscription"] = gin.H{ "daily_usage_usd": subscription.DailyUsageUSD, "weekly_usage_usd": subscription.WeeklyUsageUSD, "monthly_usage_usd": subscription.MonthlyUsageUSD, @@ -912,23 +1046,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) { "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD, "monthly_limit_usd": apiKey.Group.MonthlyLimitUSD, "expires_at": subscription.ExpiresAt, - }, + } } + if usageData != nil { resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) return } - // 余额模式:返回钱包余额 + 用量统计 - latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + // 余额模式 + latestUser, err := h.userService.GetByID(ctx, subject.UserID) if err != nil { h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") return } resp := gin.H{ + "mode": "unrestricted", "isValid": true, "planName": "钱包余额", "remaining": latestUser.Balance, @@ -938,6 +1077,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) { if usageData != nil { resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 5be32095..7f1f7977 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -818,6 +818,11 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } +// GetRateLimitData returns rate limit usage and window state for an API key. +func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + return s.apiKeyRepo.GetRateLimitData(ctx, id) +} + // UpdateRateLimitUsage atomically increments rate limit usage counters in the DB. func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { if cost <= 0 { diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index f21a2855..d64f01e0 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -315,6 +315,15 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star return stats, nil } +// GetAPIKeyModelStats returns per-model usage stats for a specific API Key. +func (s *UsageService) GetAPIKeyModelStats(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, 0, apiKeyID, 0, 0, nil, nil, nil) + if err != nil { + return nil, fmt.Errorf("get api key model stats: %w", err) + } + return stats, nil +} + // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)