# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
import inspect
import collections
import uuid
from enum import Enum
import itertools
import plum.port as port
import plum.process
from plum.process_monitor import MONITOR
import plum.process_monitor
import voluptuous
from abc import ABCMeta
from aiida.common.extendeddicts import FixedFieldsAttributeDict
import aiida.common.exceptions as exceptions
from aiida.common.lang import override, protected
from aiida.common.links import LinkType
from aiida.utils.calculation import add_source_info
from aiida.work.defaults import class_loader
import aiida.work.util
from aiida.work.util import PROCESS_LABEL_ATTR, get_or_create_output_group
from aiida.orm.calculation import Calculation
from aiida.orm.data.parameter import ParameterData
from aiida.orm.calculation.work import WorkCalculation
from aiida import LOG_LEVEL_REPORT
class DictSchema(object):
    """
    Validator that checks dictionary-like values against a voluptuous schema
    and can produce a template object pre-filled with the schema defaults.
    """

    def __init__(self, schema):
        """
        :param schema: the schema definition, wrapped in a
            ``voluptuous.Schema``
        """
        self._schema = voluptuous.Schema(schema)

    def __call__(self, value):
        """
        Validate the given value against the schema.

        :param value: a regular dictionary or a ParameterData instance, the
            latter is converted with ``get_dict()`` before validation
        :return: tuple (success, msg). success is True if the value is valid
            and False otherwise, in which case msg holds the string form of
            the validation error. For valid values msg is None.
        :rtype: tuple
        """
        try:
            if isinstance(value, ParameterData):
                value = value.get_dict()
            self._schema(value)
        except voluptuous.Invalid as error:
            return False, str(error)
        else:
            return True, None

    def get_template(self):
        """
        Build a template for the top level of the wrapped schema.

        :return: a FixedFieldsAttributeDict subclass instance, see
            :meth:`_get_template`
        """
        return self._get_template(self._schema.schema)

    def _get_template(self, dict):
        """
        Recursively build a template for the given schema mapping.

        :param dict: the (sub)schema mapping to build the template for
        :return: an instance of a dynamically created
            FixedFieldsAttributeDict subclass whose valid fields are the keys
            of the mapping
        """
        template_type = type(
            "{}Inputs".format(self.__class__.__name__),
            (FixedFieldsAttributeDict,),
            {'_valid_fields': dict.keys()})
        template = template_type()

        markers = (voluptuous.Optional, voluptuous.Required)
        for key, value in dict.iteritems():
            if isinstance(key, markers):
                # Marker keys get their declared default, or None if the
                # marker does not define one
                has_default = key.default is not voluptuous.UNDEFINED
                template[key.schema] = key.default if has_default else None
            if isinstance(value, collections.Mapping):
                # Nested schemas become nested templates
                template[key] = self._get_template(value)

        return template
class ProcessSpec(plum.process.ProcessSpec):
    """
    Process specification that adds a "fast-forwardable" flag and the ability
    to generate a template of the declared inputs.
    """

    def __init__(self):
        super(ProcessSpec, self).__init__()
        # Off by default, switched on by calling `fastforwardable()`
        self._fastforwardable = False

    def is_fastforwardable(self):
        """
        :return: True if :meth:`fastforwardable` has been called on this spec
        """
        return self._fastforwardable

    def fastforwardable(self):
        """
        Mark this specification as fast-forwardable.
        """
        self._fastforwardable = True

    def get_inputs_template(self):
        """
        Get an object that represents a template of the known inputs and their
        defaults for the :class:`Process`.

        :return: An object with attributes that represent the known inputs for
            this process. Default values will be filled in.
        """
        template_type = type(
            "{}Inputs".format(self.__class__.__name__),
            (FixedFieldsAttributeDict,),
            {'_valid_fields': self.inputs.keys()})
        template = template_type()

        for name, value_spec in self.inputs.iteritems():
            validator = value_spec.validator
            if isinstance(validator, DictSchema):
                # Dictionary inputs get a nested template from their schema
                template[name] = validator.get_template()
            else:
                # Either the declared default or None when there is none
                template[name] = value_spec.default

        return template
class Process(plum.process.Process):
    """
    This class represents an AiiDA process which can be executed and will
    have full provenance saved in the database.
    """
    __metaclass__ = ABCMeta

    # Output port name used by `out` when only a single value is passed
    SINGLE_RETURN_LINKNAME = '_return'

    class SaveKeys(Enum):
        """
        Keys used to identify things in the saved instance state bundle.
        """
        CALC_ID = 'calc_id'
        PARENT_CALC_PID = 'parent_calc_pid'

    @classmethod
    def define(cls, spec):
        """
        Declare the ports shared by all AiiDA processes.

        Adds the optional private inputs ``_store_provenance`` (default True),
        ``_description`` and ``_label`` and enables dynamic inputs and
        outputs.

        :param spec: the process specification to populate
        """
        import aiida.orm

        super(Process, cls).define(spec)

        spec.input("_store_provenance", valid_type=bool, default=True,
                   required=False)
        spec.input("_description", valid_type=basestring, required=False)
        spec.input("_label", valid_type=basestring, required=False)

        spec.dynamic_input(valid_type=(aiida.orm.Data, aiida.orm.Calculation))
        spec.dynamic_output(valid_type=aiida.orm.Data)

    @classmethod
    def get_inputs_template(cls):
        """
        :return: the inputs template generated by the specification of this
            process class
        """
        return cls.spec().get_inputs_template()

    @classmethod
    def _create_default_exec_engine(cls):
        """
        :return: the ``serial_engine`` from ``aiida.work.defaults``
        """
        from aiida.work.defaults import serial_engine
        return serial_engine

    @classmethod
    def create_db_record(cls):
        """
        Create a database calculation node that represents what happened in
        this process.

        :return: a new WorkCalculation instance
        """
        from aiida.orm.calculation.work import WorkCalculation
        calc = WorkCalculation()
        return calc

    _spec_type = ProcessSpec

    def __init__(self):
        super(Process, self).__init__()
        # Both are populated in `on_create`
        self._calc = None
        self._parent_pid = None

    @property
    def calc(self):
        """
        :return: the calculation node attached to this process, or None if it
            has not been set yet
        """
        return self._calc

    @override
    def save_instance_state(self, bundle):
        """
        Save the state needed by `on_create` to re-attach this process to its
        calculation node and to its parent.

        :param bundle: the bundle to write the instance state to
        """
        super(Process, self).save_instance_state(bundle)

        if self.inputs._store_provenance:
            assert self.calc.is_stored

        bundle[self.SaveKeys.CALC_ID.value] = self.pid
        # `on_create` restores the parent pid from the bundle, so it has to be
        # written here, otherwise the parent is lost on a save/load round trip
        if self._parent_pid is not None:
            bundle[self.SaveKeys.PARENT_CALC_PID.value] = self._parent_pid
        bundle.set_class_loader(class_loader)

    def run_after_queueing(self, wait_on):
        """
        :param wait_on: not used by this implementation
        :return: the bound ``_run`` method
        """
        return self._run

    def get_provenance_inputs_iterator(self):
        """
        :return: an iterator over the (name, value) input pairs whose name
            does not start with an underscore
        """
        return itertools.ifilter(lambda kv: not kv[0].startswith('_'),
                                 self.inputs.iteritems())

    @override
    def out(self, output_port, value=None):
        """
        Emit an output value.

        :param output_port: the name of the output port, or the value itself
            when ``value`` is omitted
        :param value: the value to emit. If None, ``output_port`` is taken to
            be the value and it is emitted on ``SINGLE_RETURN_LINKNAME``.
        """
        if value is None:
            # In this case assume that output_port is the actual value and
            # there is just one return value
            return super(Process, self).out(self.SINGLE_RETURN_LINKNAME,
                                            output_port)
        else:
            return super(Process, self).out(output_port, value)

    # Messages #####################################################
    @override
    def on_create(self, pid, inputs, saved_instance_state):
        """
        Attach a calculation node to the process and determine its pid.

        For a fresh process the parent is taken from the top of the process
        stack and a new database record is created. When recreating from a
        saved state the calculation node and parent pid are restored from the
        bundle where present.
        """
        from aiida.orm import load_node

        super(Process, self).on_create(pid, inputs, saved_instance_state)

        if saved_instance_state is None:
            # Get the parent from the top of the process stack
            try:
                self._parent_pid = aiida.work.util.ProcessStack.top().pid
            except IndexError:
                # Empty stack: this is a top level process without a parent
                pass

            self._pid = self._create_and_setup_db_record()
        else:
            if self.SaveKeys.CALC_ID.value in saved_instance_state:
                self._calc = load_node(
                    saved_instance_state[self.SaveKeys.CALC_ID.value])
                self._pid = self.calc.pk
            else:
                self._pid = self._create_and_setup_db_record()

            if self.SaveKeys.PARENT_CALC_PID.value in saved_instance_state:
                self._parent_pid = saved_instance_state[
                    self.SaveKeys.PARENT_CALC_PID.value]

        if self._logger is None:
            self.set_logger(self.calc.logger)

    @override
    def on_start(self):
        """
        Push this process onto the process stack.
        """
        super(Process, self).on_start()
        aiida.work.util.ProcessStack.push(self)

    @override
    def on_finish(self):
        """
        Called when a Process enters the FINISHED state at which point
        we set the corresponding attribute of the workcalculation node
        """
        super(Process, self).on_finish()
        self.calc._set_attr(WorkCalculation.FINISHED_KEY, True)

    @override
    def on_destroy(self):
        """
        Called when a Process enters the DESTROYED state which should be
        the final process state and so we seal the calculation node
        """
        super(Process, self).on_destroy()
        if self.calc.has_finished():
            try:
                self.calc.seal()
            except exceptions.ModificationNotAllowed:
                pass

    @override
    def _on_output_emitted(self, output_port, value, dynamic):
        """
        The process has emitted a value on the given output port.

        :param output_port: The output port name the value was emitted on
        :param value: The value emitted
        :param dynamic: Was the output port a dynamic one (i.e. not known
            beforehand?)
        """
        from aiida.orm import Data

        super(Process, self)._on_output_emitted(output_port, value, dynamic)

        assert isinstance(value, Data), \
            "Values outputted from process must be instances of AiiDA Data " \
            "types. Got: {}".format(value.__class__)

        if not value.is_stored:
            # An unstored value gets a CREATE link from this process
            value.add_link_from(self.calc, output_port, LinkType.CREATE)
            if self.inputs._store_provenance:
                value.store()

        value.add_link_from(self.calc, output_port, LinkType.RETURN)

    #################################################################

    @override
    def do_run(self):
        """
        Call ``_run`` with all inputs except the private ones.
        """
        # Exclude all private inputs
        ins = {k: v for k, v in self.inputs.iteritems()
               if not k.startswith('_')}
        return self._run(**ins)

    @protected
    def get_parent_calc(self):
        """
        Get the calculation node of the parent process.

        :return: the parent calculation node, or None if there is no parent
            or it cannot be found
        """
        from aiida.orm import load_node

        # Can't get it if we don't know our parent
        if self._parent_pid is None:
            return None

        # First try and get the process from the registry in case it is running
        try:
            return MONITOR.get_process(self._parent_pid).calc
        except ValueError:
            pass

        # Ok, maybe the pid is actually a pk...
        try:
            return load_node(pk=self._parent_pid)
        except (ValueError, exceptions.NotExistent):
            # ValueError is handled as well because the process finaliser in
            # this module guards the same `load_node(pk=...)` call against it
            pass

        # Out of options
        return None

    @protected
    def report(self, msg, *args, **kwargs):
        """
        Log a message to the logger, which should get saved to the
        database through the attached DbLogHandler. The calculation pk, the
        class name and the function name of the caller are prepended to the
        given message.

        :param msg: the message to log
        :param args: positional arguments passed on to the logger
        :param kwargs: keyword arguments passed on to the logger
        """
        # Entry 1 of the stack is the frame of the caller of this method
        message = '[{}|{}|{}]: {}'.format(
            self.calc.pk, self.__class__.__name__, inspect.stack()[1][3], msg)
        self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs)

    def _create_and_setup_db_record(self):
        """
        Create the calculation node, set it up and store it if provenance is
        enabled.

        :return: the pk of the calculation node, or its UUID if it has no pk
        """
        self._calc = self.create_db_record()
        self._setup_db_record()
        if self.inputs._store_provenance:
            self.calc.store_all()

        if self.calc.pk is not None:
            return self.calc.pk
        else:
            return uuid.UUID(self.calc.uuid)

    def _setup_db_record(self):
        """
        Label the calculation node with the process name and link it to its
        inputs and to the parent calculation, if there is one.
        """
        assert self.inputs is not None
        assert not self.calc.is_sealed, \
            "Calculation cannot be sealed when setting up the database record"

        # Save the name of this process
        self.calc._set_attr(PROCESS_LABEL_ATTR, self.__class__.__name__)

        parent_calc = self.get_parent_calc()

        # First get a dictionary of all the inputs to link, this is needed to
        # deal with things like input groups
        to_link = {}
        for name, value in self.inputs.iteritems():
            # Ignore all inputs starting with a leading underscore, and None
            # inputs
            if name.startswith('_') or value is None:
                continue

            if self.spec().has_input(name):
                if isinstance(self.spec().get_input(name),
                              port.InputGroupPort):
                    # Flatten the group, prefixing each key with the port name
                    to_link.update(
                        {"{}_{}".format(name, k): v for k, v in
                         value.iteritems()})
                else:
                    to_link[name] = value
            else:
                # It's not in the spec, so we better support dynamic inputs
                assert self.spec().has_dynamic_input()
                to_link[name] = value

        for name, value in to_link.iteritems():
            if isinstance(value, Calculation):
                value = get_or_create_output_group(value)

            if not value.is_stored:
                # If the input isn't stored then assume our parent created it
                if parent_calc:
                    value.add_link_from(parent_calc, "CREATE",
                                        link_type=LinkType.CREATE)
                if self.inputs._store_provenance:
                    value.store()

            self.calc.add_link_from(value, name)

        if parent_calc:
            self.calc.add_link_from(parent_calc, "CALL",
                                    link_type=LinkType.CALL)

        self._add_description_and_label()

    def _add_description_and_label(self):
        """
        Copy the ``_description`` and ``_label`` raw inputs, where given, onto
        the calculation node.
        """
        if self.raw_inputs:
            description = self.raw_inputs.get('_description', None)
            if description is not None:
                self._calc.description = description
            label = self.raw_inputs.get('_label', None)
            if label is not None:
                self._calc.label = label

    def _can_fast_forward(self, inputs):
        """
        :param inputs: not used by this implementation
        :return: always False
        """
        return False

    def _fast_forward(self):
        # NOTE(review): placeholder - the lookup of the old node is not
        # implemented, so `node` is None and this raises if it is ever called.
        node = None  # Here we should find the old node
        for k, v in node.get_output_dict():
            self.out(k, v)
class FunctionProcess(Process):
    """
    A Process that wraps a plain python function. Concrete subclasses are
    generated by :meth:`build`.
    """
    # Names of the positional arguments of the wrapped function, set by `build`
    _func_args = None

    @staticmethod
    def _func(*args, **kwargs):
        """
        This is used internally to store the actual function that is being
        wrapped and will be replaced by the build method.
        """
        return {}

    @staticmethod
    def build(func, **kwargs):
        """
        Build a Process from the given function. All function arguments will
        be assigned as process inputs. If keyword arguments are specified then
        these will also become inputs.

        :param func: The function to build a process from
        :param kwargs: Optional keyword arguments that will become additional
            inputs to the process
        :return: A Process class that represents the function
        :rtype: :class:`Process`
        """
        from aiida.orm.data import Data

        args, varargs, keywords, defaults = inspect.getargspec(func)

        # `defaults` only covers the LAST len(defaults) positional arguments,
        # so this is the index of the first argument that has a default
        first_default = len(args) - len(defaults or ())

        def _define(cls, spec):
            super(FunctionProcess, cls).define(spec)

            for index, name in enumerate(args):
                if index >= first_default:
                    default = defaults[index - first_default]
                else:
                    default = None
                spec.input(name, valid_type=Data, default=default)
                # Make sure to get rid of the argument from the keywords dict
                kwargs.pop(name, None)

            # Whatever is left in kwargs becomes an additional input
            for name in kwargs:
                spec.input(name)

            # If the function support kwargs then allow dynamic inputs,
            # otherwise disallow
            if keywords is not None:
                spec.dynamic_input()
            else:
                spec.no_dynamic_input()

            # We don't know what a function will return so keep it dynamic
            spec.dynamic_output(valid_type=Data)

        return type(func.__name__, (FunctionProcess,),
                    {'_func': staticmethod(func),
                     Process.define.__name__: classmethod(_define),
                     '_func_args': args})

    @classmethod
    def args_to_dict(cls, *args):
        """
        Create an input dictionary (i.e. label: value) from supplied args.

        :param args: The values to use, one for each positional argument of
            the wrapped function
        :return: A label: value dictionary
        """
        assert (len(args) == len(cls._func_args))
        return dict(zip(cls._func_args, args))

    @override
    def _setup_db_record(self):
        """
        Set up the calculation node as in the parent class, then add the
        source information of the wrapped function and use the function name
        as the process label.
        """
        super(FunctionProcess, self)._setup_db_record()
        add_source_info(self.calc, self._func)
        # Save the name of the function
        self.calc._set_attr(PROCESS_LABEL_ATTR, self._func.__name__)

    @override
    def _run(self, **kwargs):
        """
        Call the wrapped function and emit what it returns as outputs.

        :param kwargs: the inputs. Those named in ``_func_args`` are passed
            positionally, the rest as keyword arguments.
        :raises TypeError: if the function returns something other than None,
            a Data instance or a mapping
        """
        from aiida.orm.data import Data

        # Move the declared arguments out of kwargs, in declaration order
        args = [kwargs.pop(arg) for arg in self._func_args]

        outs = self._func(*args, **kwargs)
        if outs is not None:
            if isinstance(outs, Data):
                self.out(self.SINGLE_RETURN_LINKNAME, outs)
            elif isinstance(outs, collections.Mapping):
                for name, value in outs.iteritems():
                    self.out(name, value)
            else:
                raise TypeError(
                    "Workfunction returned unsupported type '{}'\n"
                    "Must be a Data type or a Mapping of string => Data".
                    format(outs.__class__))
class _ProcessFinaliser(plum.process_monitor.ProcessMonitorListener):
    """
    Take care of finalising a process when it finishes either through successful
    completion or because of a failure caused by an exception.
    """

    def __init__(self):
        # Register with the monitor to receive the process callbacks below
        MONITOR.add_monitor_listener(self)

    @override
    def on_monitored_process_destroying(self, process):
        """
        Pop the process that is being destroyed from the process stack.

        :param process: the process being destroyed
        """
        aiida.work.util.ProcessStack.pop(process)

    @override
    def on_monitored_process_failed(self, pid):
        """
        Mark the calculation node of the failed process as failed, seal it
        and pop the process from the process stack.

        :param pid: the pid of the failed process
        """
        from aiida.orm import load_node

        try:
            calc_node = load_node(pk=pid)
        except (ValueError, exceptions.NotExistent):
            # No node can be loaded for this pid, so there is nothing to mark.
            # NotExistent is handled too because `Process.get_parent_calc`
            # guards the same `load_node(pk=...)` call against it.
            pass
        else:
            calc_node._set_attr(calc_node.FAILED_KEY, True)
            calc_node.seal()
        finally:
            # Always pop the process, whether or not a node was found
            aiida.work.util.ProcessStack.pop(pid=pid)
# Have a global singleton to take care of finalising all processes. Creating
# the instance is enough: its constructor registers it as a listener on MONITOR.
_finaliser = _ProcessFinaliser()