Files
xinghuoapi/backend/internal/server/middleware/api_key_auth.go
shaw a728dfe0c6 refactor: 重构 api_key_auth 中间件,用 skipBilling 替代 7 处散落的 isUsageQuery
将中间件职责拆分为鉴权(Authentication)和计费执行(Billing Enforcement)两层:
- 鉴权层(disabled/IP/用户状态)始终执行
- 计费层(过期/配额/订阅/余额)用单一 skipBilling 守卫整块控制

/v1/usage 端点只需鉴权不需计费,skipBilling 仅出现 2 处(订阅加载错误处理 + 计费块守卫),
取代了之前 isUsageQuery 散布在 7 个 if 分支中的控制流。
2026-03-03 20:58:00 +08:00

253 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewAPIKeyAuthMiddleware creates the API key authentication middleware.
//
// It builds the handler via apiKeyAuthWithSubscription from the given
// services and configuration and returns it converted to the
// APIKeyAuthMiddleware type.
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware {
	handler := apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)
	return APIKeyAuthMiddleware(handler)
}
// apiKeyAuthWithSubscription builds the gin handler that authenticates a
// request by API key and, outside of simple run mode, enforces billing rules.
//
// The handler works in two layers:
//   - Authentication (always runs): key lookup, disabled-key check, IP
//     whitelist/blacklist, and the owning user's presence and active state.
//   - Billing enforcement (skipped as a whole when the request path is exactly
//     "/v1/usage"): key expiry and quota, then either subscription limits or
//     the user's account balance.
//
// Skipping billing for /v1/usage lets an expired or quota-exhausted key still
// query its own usage.
//
// On success the handler stores the API key, an AuthSubject, the user role
// and (when loaded) the subscription on the gin context, attaches the group to
// the request context, and calls c.Next(). On failure it aborts through
// AbortWithError and returns.
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
	return func(c *gin.Context) {
		// ── 1. Extract the API key ──────────────────────────────────
		// Keys supplied as query parameters are rejected outright.
		if strings.TrimSpace(c.Query("key")) != "" || strings.TrimSpace(c.Query("api_key")) != "" {
			AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
			return
		}

		// Preferred source: "Authorization: Bearer <key>" (scheme matched
		// case-insensitively). Any other scheme leaves the token empty.
		token := ""
		if fields := strings.SplitN(c.GetHeader("Authorization"), " ", 2); len(fields) == 2 && strings.EqualFold(fields[0], "Bearer") {
			token = strings.TrimSpace(fields[1])
		}

		// Fallbacks, tried in order only while no key has been found:
		// x-api-key, then x-goog-api-key (Gemini CLI compatibility).
		for _, header := range []string{"x-api-key", "x-goog-api-key"} {
			if token != "" {
				break
			}
			token = c.GetHeader(header)
		}

		if token == "" {
			AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
			return
		}

		// ── 2. Verify the key exists ────────────────────────────────
		apiKey, err := apiKeyService.GetByKey(c.Request.Context(), token)
		switch {
		case err == nil:
			// Key resolved; continue with the checks below.
		case errors.Is(err, service.ErrAPIKeyNotFound):
			AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
			return
		default:
			AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
			return
		}

		// ── 3. Basic authentication (always runs) ───────────────────
		// A non-active key is blocked here unless its status is expired or
		// quota-exhausted; those two are left to the billing stage so that
		// a billing-exempt request can still pass.
		deferredToBilling := apiKey.Status == service.StatusAPIKeyExpired ||
			apiKey.Status == service.StatusAPIKeyQuotaExhausted
		if !apiKey.IsActive() && !deferredToBilling {
			AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
			return
		}

		// IP restriction (whitelist/blacklist). The error message is kept
		// deliberately vague so the restriction mechanism is not revealed.
		if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
			remoteIP := ip.GetTrustedClientIP(c)
			if permitted, _ := ip.CheckIPRestrictionWithCompiledRules(remoteIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist); !permitted {
				AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
				return
			}
		}

		// The key must belong to an existing, active user.
		owner := apiKey.User
		switch {
		case owner == nil:
			AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
			return
		case !owner.IsActive():
			AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
			return
		}

		// publishIdentity stores the authenticated key, user subject and role
		// on the gin context, attaches the group to the request context, and
		// touches the key's last-used marker. The TouchLastUsed result is
		// discarded, so a failure there does not fail the request.
		publishIdentity := func() {
			c.Set(string(ContextKeyAPIKey), apiKey)
			c.Set(string(ContextKeyUser), AuthSubject{
				UserID:      owner.ID,
				Concurrency: owner.Concurrency,
			})
			c.Set(string(ContextKeyUserRole), owner.Role)
			setGroupContext(c, apiKey.Group)
			_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
		}

		// ── 4. Simple mode → early return ───────────────────────────
		// In simple run mode no subscription or billing checks are performed.
		if cfg.RunMode == config.RunModeSimple {
			publishIdentity()
			c.Next()
			return
		}

		// ── 5. Load the subscription (always, for subscription groups) ─
		// skipBilling: /v1/usage only needs authentication, so all billing
		// enforcement is bypassed for it.
		skipBilling := c.Request.URL.Path == "/v1/usage"

		var subscription *service.UserSubscription
		if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() && subscriptionService != nil {
			active, lookupErr := subscriptionService.GetActiveSubscription(
				c.Request.Context(),
				owner.ID,
				apiKey.Group.ID,
			)
			switch {
			case lookupErr == nil:
				subscription = active
			case !skipBilling:
				AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
				return
			}
			// Lookup failed but billing is skipped: let the request through
			// without a subscription; the handler returns what it can.
		}

		// ── 6. Billing enforcement (skipped as a whole when skipBilling) ─
		if !skipBilling {
			// Stored status first, then the runtime checks: even a key whose
			// status is still active is checked against time and usage.
			if apiKey.Status == service.StatusAPIKeyQuotaExhausted {
				AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
				return
			}
			if apiKey.Status == service.StatusAPIKeyExpired || apiKey.IsExpired() {
				AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
				return
			}
			if apiKey.IsQuotaExhausted() {
				AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
				return
			}

			if subscription == nil {
				// Not a subscription group, or subscriptionService was not
				// injected: fall back to the account balance check.
				if owner.Balance <= 0 {
					AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
					return
				}
			} else {
				// Subscription mode: validate the subscription and its limits.
				needsMaintenance, limitErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
				if limitErr != nil {
					status, code := 403, "SUBSCRIPTION_INVALID"
					switch {
					case errors.Is(limitErr, service.ErrDailyLimitExceeded),
						errors.Is(limitErr, service.ErrWeeklyLimitExceeded),
						errors.Is(limitErr, service.ErrMonthlyLimitExceeded):
						// Exceeded usage windows are reported as rate limiting.
						status, code = 429, "USAGE_LIMIT_EXCEEDED"
					}
					AbortWithError(c, status, code, limitErr.Error())
					return
				}
				// Window maintenance receives its own copy of the subscription.
				// NOTE(review): the original comment describes this call as
				// asynchronous and non-blocking — confirm in DoWindowMaintenance.
				if needsMaintenance {
					snapshot := *subscription
					subscriptionService.DoWindowMaintenance(&snapshot)
				}
			}
		}

		// ── 7. Populate the context → Next ──────────────────────────
		if subscription != nil {
			c.Set(string(ContextKeySubscription), subscription)
		}
		publishIdentity()
		c.Next()
	}
}
// GetAPIKeyFromContext returns the API key stored on the gin context under
// ContextKeyAPIKey.
//
// The boolean is false when nothing is stored under that key or when the
// stored value is not a *service.APIKey; in the latter case the returned
// pointer is nil.
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
	if stored, found := c.Get(string(ContextKeyAPIKey)); found {
		key, isKey := stored.(*service.APIKey)
		return key, isKey
	}
	return nil, false
}
// GetSubscriptionFromContext returns the subscription stored on the gin
// context under ContextKeySubscription.
//
// The boolean is false when nothing is stored under that key or when the
// stored value is not a *service.UserSubscription; in the latter case the
// returned pointer is nil.
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
	stored, found := c.Get(string(ContextKeySubscription))
	if !found {
		return nil, false
	}
	sub, isSub := stored.(*service.UserSubscription)
	return sub, isSub
}
// setGroupContext attaches group to the request's context under ctxkey.Group.
//
// It does nothing when service.IsGroupContextValid rejects group, or when the
// request context already carries a group with the same ID that itself passes
// service.IsGroupContextValid. Otherwise c.Request is replaced by a copy whose
// context holds group.
func setGroupContext(c *gin.Context, group *service.Group) {
	if !service.IsGroupContextValid(group) {
		return
	}

	reqCtx := c.Request.Context()

	// A failed type assertion yields a nil pointer, so a single nil check
	// covers both "no value stored" and "value of a different type".
	current, _ := reqCtx.Value(ctxkey.Group).(*service.Group)
	if current != nil && current.ID == group.ID && service.IsGroupContextValid(current) {
		return
	}

	c.Request = c.Request.WithContext(context.WithValue(reqCtx, ctxkey.Group, group))
}