Source code for aiida.backends.tests.engine.test_transport

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

from six.moves import range
from tornado.gen import coroutine, Return

from aiida.backends.testbase import AiidaTestCase
from aiida.engine.transports import TransportQueue
from aiida import orm


[docs]class TestTransportQueue(AiidaTestCase): """ Tests for the transport queue """
[docs] def setUp(self, *args, **kwargs): """ Set up a simple authinfo and for later use """ super(TestTransportQueue, self).setUp(*args, **kwargs) self.authinfo = orm.AuthInfo(computer=self.computer, user=orm.User.objects.get_default()).store()
[docs] def tearDown(self, *args, **kwargs): orm.AuthInfo.objects.delete(self.authinfo.id) super(TestTransportQueue, self).tearDown(*args, **kwargs)
[docs] def test_simple_request(self): """ Test a simple transport request """ queue = TransportQueue() loop = queue.loop() @coroutine def test(): trans = None with queue.request_transport(self.authinfo) as request: trans = yield request self.assertTrue(trans.is_open) self.assertFalse(trans.is_open) loop.run_sync(lambda: test())
[docs] def test_get_transport_nested(self): """ Test nesting calls to get the same transport """ transport_queue = TransportQueue() loop = transport_queue.loop() @coroutine def nested(queue, authinfo): with queue.request_transport(authinfo) as request1: trans1 = yield request1 self.assertTrue(trans1.is_open) with queue.request_transport(authinfo) as request2: trans2 = yield request2 self.assertIs(trans1, trans2) self.assertTrue(trans2.is_open) loop.run_sync(lambda: nested(transport_queue, self.authinfo))
[docs] def test_get_transport_interleaved(self): """ Test interleaved calls to get the same transport """ transport_queue = TransportQueue() loop = transport_queue.loop() @coroutine def interleaved(authinfo): with transport_queue.request_transport(authinfo) as trans_future: yield trans_future loop.run_sync(lambda: [interleaved(self.authinfo), interleaved(self.authinfo)])
[docs] def test_return_from_context(self): """ Test raising a Return from coroutine context """ queue = TransportQueue() loop = queue.loop() @coroutine def test(): with queue.request_transport(self.authinfo) as request: trans = yield request raise Return(trans.is_open) retval = loop.run_sync(lambda: test()) self.assertTrue(retval)
[docs] def test_open_fail(self): """ Test that if opening fails """ queue = TransportQueue() loop = queue.loop() @coroutine def test(): with queue.request_transport(self.authinfo) as request: yield request def broken_open(trans): raise RuntimeError("Could not open transport") original = None try: # Let's put in a broken open method original = self.authinfo.get_transport().__class__.open self.authinfo.get_transport().__class__.open = broken_open with self.assertRaises(RuntimeError): loop.run_sync(lambda: test()) finally: self.authinfo.get_transport().__class__.open = original
[docs] def test_safe_interval(self): """Verify that the safe interval for a given in transport is respected by the transport queue.""" # Temporarily set the safe open interval for the default transport to a finite value transport_class = self.authinfo.get_transport().__class__ original_interval = transport_class._DEFAULT_SAFE_OPEN_INTERVAL try: transport_class._DEFAULT_SAFE_OPEN_INTERVAL = 0.25 import time queue = TransportQueue() loop = queue.loop() time_start = time.time() @coroutine def test(iteration): trans = None with queue.request_transport(self.authinfo) as request: trans = yield request time_current = time.time() time_elapsed = time_current - time_start time_minimum = trans.get_safe_open_interval() * (iteration + 1) self.assertTrue(time_elapsed > time_minimum, 'transport safe interval was violated') for i in range(5): loop.run_sync(lambda: test(i)) finally: transport_class._DEFAULT_SAFE_OPEN_INTERVAL = original_interval