Source code for aiida.backends.tests.engine.test_utils

# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import unittest
from tornado.ioloop import IOLoop
from tornado.gen import coroutine

from aiida import orm
from aiida.backends.testbase import AiidaTestCase
from aiida.engine.utils import exponential_backoff_retry, RefObjectStore

# ``ITERATION`` is a call counter that the test coroutines increment via ``global``.
# ``MAX_ITERATIONS`` is the threshold: a coroutine raises RuntimeError on every call
# while the incremented counter is still below this value.
ITERATION, MAX_ITERATIONS = 0, 3


class TestExponentialBackoffRetry(AiidaTestCase):
    """Tests for the exponential backoff retry coroutine."""

    @classmethod
    def setUpClass(cls, *args, **kwargs):
        """Set up a simple authinfo for later use."""
        super(TestExponentialBackoffRetry, cls).setUpClass(*args, **kwargs)
        cls.authinfo = orm.AuthInfo(computer=cls.computer, user=orm.User.objects.get_default())
        cls.authinfo.store()

    def test_exponential_backoff_success(self):
        """Test that exponential backoff will successfully catch exceptions as long as max_attempts is not exceeded."""
        # Reset the shared module-level counter. Without the ``global`` declaration the
        # assignment would only bind a local name, leaving whatever value a previously
        # executed test stored in the module counter.
        global ITERATION  # pylint: disable=global-statement
        ITERATION = 0
        loop = IOLoop()

        @coroutine
        def coro():
            """A function that will raise RuntimeError as long as ITERATION is smaller than MAX_ITERATIONS."""
            global ITERATION  # pylint: disable=global-statement
            ITERATION += 1
            if ITERATION < MAX_ITERATIONS:
                raise RuntimeError

        # One more attempt than the number of failing calls, so the final attempt succeeds.
        max_attempts = MAX_ITERATIONS + 1
        loop.run_sync(lambda: exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts))

    def test_exponential_backoff_max_attempts_exceeded(self):
        """Test that exponential backoff will finally raise if max_attempts is exceeded"""
        # Reset the shared module-level counter; see the note in the success test.
        global ITERATION  # pylint: disable=global-statement
        ITERATION = 0
        loop = IOLoop()

        @coroutine
        def coro():
            """A function that will raise RuntimeError as long as ITERATION is smaller than MAX_ITERATIONS."""
            global ITERATION  # pylint: disable=global-statement
            ITERATION += 1
            if ITERATION < MAX_ITERATIONS:
                raise RuntimeError

        # Fewer attempts than the coroutine needs to succeed, so the error must propagate.
        max_attempts = MAX_ITERATIONS - 1
        with self.assertRaises(RuntimeError):
            loop.run_sync(lambda: exponential_backoff_retry(coro, initial_interval=0.1, max_attempts=max_attempts))
class RefObjectsStore(unittest.TestCase):
    """Tests for the reference-counting behaviour of ``RefObjectStore``."""

    def test_simple(self):
        """
        Test the reference counting works
        """
        IDENTIFIER = 'a'
        OBJECT = 'my string'
        obj_store = RefObjectStore()

        with obj_store.get(IDENTIFIER, lambda: OBJECT) as obj:
            # Make sure we got back the same object
            self.assertIs(OBJECT, obj)

            # Now check that the reference has the correct information
            ref = obj_store._objects[IDENTIFIER]
            self.assertEqual(OBJECT, ref._obj)
            self.assertEqual(1, ref.count)

            # Now request the object again (no constructor needed, it already exists)
            with obj_store.get(IDENTIFIER) as obj2:
                # ...and check the reference has had its count upped
                self.assertEqual(OBJECT, obj2)
                self.assertEqual(2, ref.count)

            # Now it should have been reduced
            self.assertEqual(1, ref.count)

        # Finally the store should be empty (there are no more references)
        self.assertEqual(0, len(obj_store._objects))

    def test_get_no_constructor(self):
        """
        Test that trying to get an object that does not exist, without providing
        a means to construct it, raises ValueError
        """
        obj_store = RefObjectStore()
        with self.assertRaises(ValueError):
            with obj_store.get('a'):
                pass

    def test_construct(self):
        """
        Test that construction only gets called when used
        """
        IDENTIFIER = 'a'
        OBJECT = 'my string'

        # Use a list for a single number so the nested function can mutate it
        # (this file targets Python 2 as well, so ``nonlocal`` is unavailable)
        times_constructed = [
            0,
        ]

        def construct():
            times_constructed[0] += 1
            return OBJECT

        obj_store = RefObjectStore()
        with obj_store.get(IDENTIFIER, construct):
            self.assertEqual(1, times_constructed[0])
            # Nested request while the object is still referenced: no new construction
            with obj_store.get(IDENTIFIER, construct):
                self.assertEqual(1, times_constructed[0])

        # Now the object should be removed and so another call to get
        # should create
        with obj_store.get(IDENTIFIER, construct):
            self.assertEqual(2, times_constructed[0])