# Source code for aiida.backends.tests.base_dataclasses

# -*- coding: utf-8 -*-
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
import unittest
import operator

from aiida.backends.testbase import AiidaTestCase
from aiida.common.exceptions import ModificationNotAllowed
from aiida.orm import load_node
from import (
    NumericType, Float, Str, Bool, Int, get_true_node, get_false_node)
import as base

class TestList(AiidaTestCase):
    """Tests for the ``base.List`` data node, before and after storing."""

    def test_creation(self):
        """A freshly created List is empty and indexing it raises IndexError."""
        node = base.List()
        self.assertEqual(len(node), 0)
        with self.assertRaises(IndexError):
            node[0]  # pylint: disable=pointless-statement

    def test_append(self):
        """Appending one element works on an unstored and on a stored List."""
        def do_checks(lst_node):
            self.assertEqual(len(lst_node), 1)
            self.assertEqual(lst_node[0], 4)

        node = base.List()
        node.append(4)
        do_checks(node)

        # Try the same after storing: the contents must survive store().
        node = base.List()
        node.append(4)
        node.store()
        do_checks(node)

    def test_extend(self):
        """Extending with a Python list appends its elements in order."""
        lst = [1, 2, 3]

        def do_checks(lst_node):
            self.assertEqual(len(lst_node), len(lst))
            # Do an element wise comparison
            for expected, actual in zip(lst, lst_node):
                self.assertEqual(expected, actual)

        node = base.List()
        node.extend(lst)
        do_checks(node)

        # Further extend: the contents are now ``lst`` repeated twice
        node.extend(lst)
        self.assertEqual(len(node), len(lst) * 2)
        # Compare every element, including the repeated second half
        for i in range(len(node)):
            self.assertEqual(lst[i % len(lst)], node[i])

        # Now try after storing
        node = base.List()
        node.extend(lst)
        node.store()
        do_checks(node)

    def test_mutability(self):
        """Once stored, every mutating List method raises ModificationNotAllowed."""
        node = base.List()
        node.append(5)
        # The unstored List in test_append accepts append(), so the node has
        # to be stored for the calls below to be rejected.
        node.store()

        # Test all mutable calls are now disallowed
        with self.assertRaises(ModificationNotAllowed):
            node.append(5)
        with self.assertRaises(ModificationNotAllowed):
            node.extend([5])
        with self.assertRaises(ModificationNotAllowed):
            node.insert(0, 2)
        with self.assertRaises(ModificationNotAllowed):
            node.remove(0)
        with self.assertRaises(ModificationNotAllowed):
            node.pop()
        with self.assertRaises(ModificationNotAllowed):
            node.sort()
        with self.assertRaises(ModificationNotAllowed):
            node.reverse()
class TestFloat(AiidaTestCase):
    """Tests for the base data types: construction, loading and arithmetic."""

    def setUp(self):
        super(TestFloat, self).setUp()
        self.value = Float()
        # Every base type exercised by test_load
        self.all_types = [Int, Float, Bool, Str]

    def test_create(self):
        """Default and explicit construction of Float, Int, Bool and Str."""
        a = Float()
        # Check that initial value is zero
        self.assertEqual(a.value, 0.0)
        f = Float(6.0)
        self.assertEqual(f.value, 6.)
        self.assertEqual(f, Float(6.0))

        i = Int()
        self.assertEqual(i.value, 0)
        i = Int(6)
        self.assertEqual(i.value, 6)
        self.assertEqual(f, i)

        b = Bool()
        self.assertEqual(b.value, False)
        b = Bool(False)
        self.assertEqual(b.value, False)
        self.assertEqual(b.value, get_false_node())
        b = Bool(True)
        self.assertEqual(b.value, True)
        self.assertEqual(b.value, get_true_node())

        s = Str()
        self.assertEqual(s.value, "")
        s = Str('Hello')
        self.assertEqual(s.value, 'Hello')

    def test_load(self):
        """Each base type can be stored, loaded back by pk, and compares equal."""
        for data_type in self.all_types:
            node = data_type()
            # load_node looks the node up by its database identifier, so the
            # node has to be stored first.
            node.store()
            loaded = load_node(node.pk)
            self.assertEqual(node, loaded)

    def test_add(self):
        """Addition of Floats with Floats and with native floats, incl. +=."""
        a = Float(4)
        b = Float(5)
        # Check adding two db Floats
        res = a + b
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 9.0)

        # Check adding db Float and native (both ways)
        res = a + 5.0
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 9.0)
        res = 5.0 + a
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 9.0)

        # Inplace
        a = Float(4)
        a += b
        self.assertEqual(a, 9.0)

        a = Float(4)
        a += 5
        self.assertEqual(a, 9.0)

    def test_mul(self):
        """Multiplication of Floats with Floats and with native floats, incl. *=."""
        a = Float(4)
        b = Float(5)
        # Check multiplying two db Floats
        res = a * b
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 20.0)

        # Check multiplying db Float and native (both ways)
        res = a * 5.0
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 20)
        res = 5.0 * a
        self.assertIsInstance(res, NumericType)
        self.assertEqual(res, 20.0)

        # Inplace
        a = Float(4)
        a *= b
        self.assertEqual(a, 20)

        a = Float(4)
        a *= 5
        self.assertEqual(a, 20)

    def test_power(self):
        """Raising a Float to a Float power yields the native result."""
        a = Float(4)
        b = Float(2)
        res = a ** b
        self.assertEqual(res.value, 16.)
class TestFloatIntMix(AiidaTestCase):
    """Tests for binary operations that mix Float and Int nodes."""

    def test_operator(self):
        """Mixed Float/Int operations match native Python in value and type.

        Each operator is applied in both operand orders. The result's ``_type``
        must equal the type Python produces when the operator is applied to
        the raw ``.value`` operands, and the result must compare equal to
        that native value.
        """
        a = Float(2.2)
        b = Int(3)

        for op in [operator.add, operator.mul, operator.pow, operator.lt,
                   operator.le, operator.gt, operator.ge, operator.iadd,
                   operator.imul]:
            for x, y in [(a, b), (b, a)]:
                # Compute the native expectation first, so that an in-place
                # operator (iadd/imul) cannot affect the expected value.
                c_val = op(x.value, y.value)
                c = op(x, y)
                self.assertEqual(c._type, type(c_val))
                self.assertEqual(c, c_val)