Source code for aiida.backends.tests.work.daemon

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
from aiida.backends.testbase import AiidaTestCase
import tempfile
from shutil import rmtree

from plum.wait_ons import checkpoint

from aiida.work.persistence import Persistence
from aiida.orm.data.base import get_true_node
import aiida.work.daemon as daemon
from aiida.work.process import Process
from aiida.work.process_registry import ProcessRegistry
from aiida.work.run import submit
from aiida.common.lang import override
from aiida.orm import load_node
import aiida.work.util as util
from aiida.work.test_utils import DummyProcess, ExceptionProcess


class ProcessEventsTester(Process):
    """
    Process that records the lifecycle callbacks it receives.

    Each overridden callback emits a ``True`` node (from ``get_true_node``)
    on an optional output port named after that callback, so a test can
    inspect the outputs to see which callbacks fired.
    """
    # NOTE(review): "wait" is declared as an output in define() and emitted by
    # on_wait(), but it is not listed here -- confirm whether that is intended.
    EVENTS = ["create", "run", "continue_", "finish", "emitted", "stop",
              "destroy", ]

    @classmethod
    def define(cls, spec):
        super(ProcessEventsTester, cls).define(spec)
        # One optional output port per recorded callback.
        recorded = ("create", "run", "wait", "continue_", "finish",
                    "emitted", "stop", "destroy")
        for port_name in recorded:
            spec.optional_output(port_name)

    def __init__(self):
        super(ProcessEventsTester, self).__init__()
        # Ensures the "emitted" output is produced at most once.
        self._emitted = False

    @override
    def on_create(self, pid, inputs, saved_instance_state):
        super(ProcessEventsTester, self).on_create(
            pid, inputs, saved_instance_state)
        self.out("create", get_true_node())

    @override
    def on_run(self):
        super(ProcessEventsTester, self).on_run()
        self.out("run", get_true_node())

    @override
    def _on_output_emitted(self, output_port, value, dynamic):
        super(ProcessEventsTester, self)._on_output_emitted(
            output_port, value, dynamic)
        if self._emitted:
            return
        # The flag is set before calling out(): out() presumably re-enters
        # this hook, and the flag stops that from emitting again.
        self._emitted = True
        self.out("emitted", get_true_node())

    @override
    def on_wait(self, wait_on):
        super(ProcessEventsTester, self).on_wait(wait_on)
        self.out("wait", get_true_node())

    @override
    def on_continue(self, wait_on):
        super(ProcessEventsTester, self).on_continue(wait_on)
        self.out("continue_", get_true_node())

    @override
    def on_finish(self):
        super(ProcessEventsTester, self).on_finish()
        self.out("finish", get_true_node())

    @override
    def on_stop(self):
        super(ProcessEventsTester, self).on_stop()
        self.out("stop", get_true_node())

    @override
    def on_destroy(self):
        super(ProcessEventsTester, self).on_destroy()
        self.out("destroy", get_true_node())

    @override
    def _run(self):
        # Return a checkpoint wait-on whose continuation is finish();
        # presumably this makes the process pass through wait/continue.
        return checkpoint(self.finish)

    def finish(self, wait_on):
        # Continuation for the checkpoint; there is nothing left to do.
        pass
class FailCreateFromSavedStateProcess(DummyProcess):
    """
    This class emulates a failure that occurs when loading the process from a
    saved state.
    """

    @override
    def on_create(self, pid, inputs, saved_instance_state):
        super(FailCreateFromSavedStateProcess, self).on_create(
            pid, inputs, saved_instance_state)
        # A fresh creation (no saved state) succeeds; only a restore fails.
        if saved_instance_state is None:
            return
        raise RuntimeError()
class TestDaemon(AiidaTestCase):
    """Tests for ticking the workflow engine in :mod:`aiida.work.daemon`."""

    def setUp(self):
        # Every test must start with an empty process stack.
        self.assertEqual(len(util.ProcessStack.stack()), 0)
        # Isolated on-disk store for submitted processes; removed in tearDown.
        self.storedir = tempfile.mkdtemp()
        self.storage = Persistence.create_from_basedir(self.storedir)

    def tearDown(self):
        try:
            # No process may be left on the stack once a test finishes.
            self.assertEqual(len(util.ProcessStack.stack()), 0)
        finally:
            # Always remove the temporary store, even if the check above
            # fails, so failing tests do not leak directories.
            rmtree(self.storedir)

    def _tick_until_done(self, max_ticks=10):
        """
        Tick the workflow engine on ``self.storage`` until it reports no more
        work, failing the test if it is still busy after ``max_ticks`` ticks.

        :param max_ticks: maximum number of ticks that may report more work
        """
        ticks = 0
        while daemon.tick_workflow_engine(self.storage,
                                          print_exceptions=False):
            self.assertLess(ticks, max_ticks,
                            "Engine not done after {} ticks".format(max_ticks))
            ticks += 1

    def test_submit(self):
        # This call should create an entry in the database with a PK.
        # Use the per-test store so nothing is left in the default jobs store.
        rinfo = submit(DummyProcess, _jobs_store=self.storage)
        self.assertIsNotNone(rinfo)
        self.assertIsNotNone(load_node(pk=rinfo.pid))

    def test_tick(self):
        registry = ProcessRegistry()
        rinfo = submit(ProcessEventsTester, _jobs_store=self.storage)
        self._tick_until_done()
        self.assertTrue(registry.has_finished(rinfo.pid))

    def test_multiple_processes(self):
        submit(DummyProcess, _jobs_store=self.storage)
        submit(ExceptionProcess, _jobs_store=self.storage)
        submit(ExceptionProcess, _jobs_store=self.storage)
        submit(DummyProcess, _jobs_store=self.storage)
        # A single tick is expected to leave no further work, whether the
        # processes finish normally or raise.
        self.assertFalse(
            daemon.tick_workflow_engine(self.storage, print_exceptions=False))

    def test_create_fail(self):
        registry = ProcessRegistry()
        dp_rinfo = submit(DummyProcess, _jobs_store=self.storage)
        fail_rinfo = submit(FailCreateFromSavedStateProcess,
                            _jobs_store=self.storage)
        self._tick_until_done()
        self.assertTrue(registry.has_finished(dp_rinfo.pid))
        self.assertTrue(registry.has_finished(fail_rinfo.pid))