diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 47167cf2..182ea5f8 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -170,6 +170,71 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := &service.Group{ + ID: 101, + Name: "g1", + Status: service.StatusActive, + Platform: service.PlatformAnthropic, + Hydrated: true, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + router := gin.New() + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + + invalidGroup := &service.Group{ + ID: group.ID, + Platform: group.Platform, + Status: group.Status, + } + router.GET("/t", func(c *gin.Context) { + groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) + if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup { + c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup)) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 5520f78e..3c8a5f78 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1145,6 +1145,29 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) require.Equal(t, 1, groupRepo.getByIDLiteCalls) } +func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) { + groupID := int64(42) + invalidGroup := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + } + hydratedGroup := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup) + svc := &GatewayService{} + ctx = svc.withGroupContext(ctx, hydratedGroup) + + got, ok := ctx.Value(ctxkey.Group).(*Group) + require.True(t, ok) + require.Same(t, hydratedGroup, got) +} + func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ctx := context.Background() groupID := int64(10)