feat(subscription): 有界队列执行维护并改进鉴权解析
This commit is contained in:
@@ -58,8 +58,13 @@ func adminAuth(
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
return
|
||||
}
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
|
||||
@@ -35,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
apiKeyString = parts[1]
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
apiKeyString = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
go subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
|
||||
@@ -57,6 +57,57 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
cfg.SubscriptionMaintenance.WorkerCount = 1
|
||||
cfg.SubscriptionMaintenance.QueueSize = 1
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
past := time.Now().Add(-48 * time.Hour)
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyWindowStart: &past,
|
||||
DailyUsageUSD: 0,
|
||||
}
|
||||
maintenanceCalled := make(chan struct{}, 1)
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
|
||||
maintenanceCalled <- struct{}{}
|
||||
return nil
|
||||
},
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||
t.Cleanup(subscriptionService.Stop)
|
||||
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
select {
|
||||
case <-maintenanceCalled:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected maintenance to be scheduled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
@@ -71,6 +122,20 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "bearer "+apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
tokenString := strings.TrimSpace(parts[1])
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
|
||||
@@ -84,6 +84,28 @@ func TestJWTAuth_ValidToken(t *testing.T) {
|
||||
require.Equal(t, "user", body["role"])
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
|
||||
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
v := c.Request.Context().Value(ctxkey.ClientRequestID)
|
||||
require.NotNil(t, v)
|
||||
id, ok := v.(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestClientRequestID_PreservesExisting(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep", id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestBodyLimit(4))
|
||||
r.POST("/t", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
require.Error(t, err)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ForcePlatform("anthropic"))
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
require.True(t, HasForcePlatform(c))
|
||||
v, ok := GetForcePlatformFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "anthropic", v)
|
||||
|
||||
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
|
||||
require.Equal(t, "anthropic", ctxV)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
|
||||
c.Set(string(ContextKeyUserRole), "admin")
|
||||
|
||||
sub, ok := GetAuthSubjectFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), sub.UserID)
|
||||
require.Equal(t, 2, sub.Concurrency)
|
||||
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "admin", role)
|
||||
}
|
||||
|
||||
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
|
||||
key := &service.APIKey{ID: 1}
|
||||
c.Set(string(ContextKeyAPIKey), key)
|
||||
gotKey, ok := GetAPIKeyFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), gotKey.ID)
|
||||
|
||||
sub := &service.UserSubscription{ID: 2}
|
||||
c.Set(string(ContextKeySubscription), sub)
|
||||
gotSub, ok := GetSubscriptionFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(2), gotSub.ID)
|
||||
}
|
||||
Reference in New Issue
Block a user