refactor: 调整项目结构为单向依赖
This commit is contained in:
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -47,7 +46,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -82,8 +81,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(user.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -92,10 +91,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("User concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
@@ -106,7 +105,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. 【新增】Wait后二次检查余额/订阅
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
log.Printf("Billing eligibility check failed after wait: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
|
||||
return
|
||||
@@ -133,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
@@ -158,7 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: user,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
@@ -198,7 +197,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -223,7 +222,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 余额模式:返回钱包余额
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
|
||||
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
|
||||
return
|
||||
@@ -241,7 +240,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
// 逻辑:
|
||||
// 1. 如果日/周/月任一限额达到100%,返回0
|
||||
// 2. 否则返回所有已配置周期中剩余额度的最小值
|
||||
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
|
||||
func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
|
||||
var remainingValues []float64
|
||||
|
||||
// 检查日限额
|
||||
@@ -334,7 +333,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := middleware2.GetUserFromContext(c)
|
||||
_, ok = middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
@@ -366,7 +365,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
|
||||
// 校验 billing eligibility(订阅/余额)
|
||||
// 【注意】不计算并发,但需要校验订阅/余额
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user