feat(subscription): 有界队列执行维护并改进鉴权解析

This commit is contained in:
yangjianbo
2026-02-10 00:37:47 +08:00
parent 2bfb16291f
commit 3fcb0cc37c
13 changed files with 558 additions and 66 deletions

View File

@@ -58,8 +58,13 @@ func adminAuth(
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
token := strings.TrimSpace(parts[1])
if token == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
return
}
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()

View File

@@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
apiKeyString = strings.TrimSpace(parts[1])
}
}
@@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额

View File

@@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
},
}
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
cfg.SubscriptionMaintenance.WorkerCount = 1
cfg.SubscriptionMaintenance.QueueSize = 1
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
past := time.Now().Add(-48 * time.Hour)
sub := &service.UserSubscription{
ID: 55,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyWindowStart: &past,
DailyUsageUSD: 0,
}
maintenanceCalled := make(chan struct{}, 1)
subscriptionRepo := &stubUserSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
maintenanceCalled <- struct{}{}
return nil
},
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
t.Cleanup(subscriptionService.Stop)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
select {
case <-maintenanceCalled:
// ok
case <-time.After(time.Second):
t.Fatalf("expected maintenance to be scheduled")
}
})
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "bearer "+apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)

View File

@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
tokenString := strings.TrimSpace(parts[1])
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return

View File

@@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) {
require.Equal(t, "user", body["role"])
}
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
Concurrency: 5,
TokenVersion: 1,
}
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)

View File

@@ -0,0 +1,126 @@
//go:build unit
package middleware
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ClientRequestID())
r.GET("/t", func(c *gin.Context) {
v := c.Request.Context().Value(ctxkey.ClientRequestID)
require.NotNil(t, v)
id, ok := v.(string)
require.True(t, ok)
require.NotEmpty(t, id)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestClientRequestID_PreservesExisting(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ClientRequestID())
r.GET("/t", func(c *gin.Context) {
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
require.True(t, ok)
require.Equal(t, "keep", id)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestBodyLimit(4))
r.POST("/t", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
require.Error(t, err)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ForcePlatform("anthropic"))
r.GET("/t", func(c *gin.Context) {
require.True(t, HasForcePlatform(c))
v, ok := GetForcePlatformFromContext(c)
require.True(t, ok)
require.Equal(t, "anthropic", v)
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
require.Equal(t, "anthropic", ctxV)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
c := &gin.Context{}
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
c.Set(string(ContextKeyUserRole), "admin")
sub, ok := GetAuthSubjectFromContext(c)
require.True(t, ok)
require.Equal(t, int64(1), sub.UserID)
require.Equal(t, 2, sub.Concurrency)
role, ok := GetUserRoleFromContext(c)
require.True(t, ok)
require.Equal(t, "admin", role)
}
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
c := &gin.Context{}
key := &service.APIKey{ID: 1}
c.Set(string(ContextKeyAPIKey), key)
gotKey, ok := GetAPIKeyFromContext(c)
require.True(t, ok)
require.Equal(t, int64(1), gotKey.ID)
sub := &service.UserSubscription{ID: 2}
c.Set(string(ContextKeySubscription), sub)
gotSub, ok := GetSubscriptionFromContext(c)
require.True(t, ok)
require.Equal(t, int64(2), gotSub.ID)
}