From 435f6938928fe86782bb92938034347ea94e14f0 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 09:27:47 +0800 Subject: [PATCH] =?UTF-8?q?test(=E5=88=86=E7=BB=84):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=97=A0=E6=95=88=E4=B8=8A=E4=B8=8B=E6=96=87=E8=A6=86=E7=9B=96?= =?UTF-8?q?=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补充 GatewayService 与 APIKey 中间件对无效 ctxkey.Group 的覆盖行为测试 测试: make test-backend --- .../server/middleware/api_key_auth_test.go | 65 +++++++++++++++++++ .../service/gateway_multiplatform_test.go | 23 +++++++ 2 files changed, 88 insertions(+) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 47167cf2..182ea5f8 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -170,6 +170,71 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := &service.Group{ + ID: 101, + Name: "g1", + Status: service.StatusActive, + Platform: service.PlatformAnthropic, + Hydrated: true, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + router := gin.New() + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + + invalidGroup := &service.Group{ + ID: group.ID, + Platform: group.Platform, + Status: group.Status, + } + router.GET("/t", func(c *gin.Context) { + groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) + if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup { + c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup)) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 5520f78e..3c8a5f78 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1145,6 +1145,29 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) require.Equal(t, 1, groupRepo.getByIDLiteCalls) } +func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) { + groupID := int64(42) + invalidGroup := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + } + hydratedGroup := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup) + svc := &GatewayService{} + ctx = svc.withGroupContext(ctx, hydratedGroup) + + got, ok := ctx.Value(ctxkey.Group).(*Group) + require.True(t, ok) + require.Same(t, hydratedGroup, got) +} + func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ctx := context.Background() groupID := int64(10)