# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from sqlalchemy import ForeignKey, select, func, join, and_, case
from sqlalchemy.orm import (
relationship, backref, Query, mapper,
foreign, aliased
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.schema import Column, UniqueConstraint
from sqlalchemy.types import Integer, String, Boolean, DateTime, Text
# Specific to PGSQL. If needed to be agnostic
# http://docs.sqlalchemy.org/en/rel_0_9/core/custom_types.html?highlight=guid#backend-agnostic-guid-type
# Or maybe rely on sqlalchemy-utils UUID type
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy_utils.types.choice import ChoiceType
from aiida.utils import timezone
from aiida.backends.sqlalchemy.models.base import Base, _QueryProperty, _AiidaQuery
from aiida.backends.sqlalchemy.models.utils import uuid_func
from aiida.common import aiidalogger
from aiida.common.exceptions import DbContentError, MissingPluginError
from aiida.common.datastructures import calc_states, _sorted_datastates, sort_states
from aiida.backends.sqlalchemy.models.user import DbUser
from aiida.backends.sqlalchemy.models.computer import DbComputer
class DbCalcState(Base):
    """
    SQLAlchemy model for the ``db_dbcalcstate`` table.

    Each row records that a node (``dbnode_id``) reached a given ``state``,
    together with the ``time`` at which the row was created. The pair
    (``dbnode_id``, ``state``) is unique.
    """
    __tablename__ = "db_dbcalcstate"

    id = Column(Integer, primary_key=True)
    dbnode_id = Column(
        Integer,
        ForeignKey(
            'db_dbnode.id', ondelete="CASCADE",
            deferrable=True, initially="DEFERRED"
        )
    )
    dbnode = relationship(
        'DbNode', backref=backref('dbstates', passive_deletes=True),
    )

    # Note: this is suboptimal: calc_states is not sorted
    # therefore the order is not the expected one. If we
    # were to use the correct order here, we could directly sort
    # without specifying a custom order. This is probably faster,
    # but requires a schema migration at this point.
    #
    # The choices are materialized into a list instead of being passed as a
    # bare generator. ChoiceType keeps the ``choices`` object it receives and
    # also iterates it to build its lookup table, so a generator would be left
    # exhausted on the type. NOTE(review): verify against the installed
    # sqlalchemy_utils version.
    # ``list(genexp)`` is used rather than a list comprehension because, on
    # Python 2, a list comprehension in a class body leaks its loop variable
    # as a class attribute.
    state = Column(
        ChoiceType(list((_state, _state) for _state in calc_states)),
        index=True
    )
    time = Column(DateTime(timezone=True), default=timezone.now)

    __table_args__ = (
        UniqueConstraint('dbnode_id', 'state'),
    )
class DbNode(Base):
    """
    SQLAlchemy model for the ``db_dbnode`` table.

    Node attributes and extras are stored in two JSONB columns. Links between
    nodes go through the ``db_dblink`` table: ``outputs_q`` and its backref
    ``inputs_q`` are dynamic relationships over it.
    """
    __tablename__ = "db_dbnode"

    aiida_query = _QueryProperty(_AiidaQuery)

    id = Column(Integer, primary_key=True)
    uuid = Column(UUID(as_uuid=True), default=uuid_func)
    type = Column(String(255), index=True)
    # NOTE(review): nullable *and* defaulting to "" -- does it make sense to
    # allow both? Changing the nullability would be a schema change.
    label = Column(String(255), index=True, nullable=True, default="")
    description = Column(Text(), nullable=True, default="")
    ctime = Column(DateTime(timezone=True), default=timezone.now)
    mtime = Column(DateTime(timezone=True), default=timezone.now)
    nodeversion = Column(Integer, default=1)
    public = Column(Boolean, default=False)
    attributes = Column(JSONB)
    extras = Column(JSONB)

    dbcomputer_id = Column(
        Integer,
        ForeignKey('db_dbcomputer.id', deferrable=True, initially="DEFERRED", ondelete="RESTRICT"),
        nullable=True
    )

    # Same ondelete behaviour (restrict) as dbcomputer_id.
    user_id = Column(
        Integer,
        ForeignKey(
            'db_dbuser.id', deferrable=True, initially="DEFERRED", ondelete="restrict"
        ),
        nullable=False
    )

    # TODO SP: The 'passive_deletes=all' argument here means that SQLAlchemy
    # won't take care of automatic deleting in the DbLink table. This still
    # isn't exactly the same behaviour as with Django. The solution to
    # this is probably an ON DELETE inside the DB. On removing node with id=x,
    # we would remove all links with x as an output.

    ######### RELATIONSHIPS ################
    dbcomputer = relationship(
        'DbComputer',
        backref=backref('dbnodes', passive_deletes='all', cascade='merge')
    )

    # User
    user = relationship(
        'DbUser',
        backref=backref('dbnodes', passive_deletes='all', cascade='merge', )
    )

    # outputs via db_dblink table
    outputs_q = relationship(
        "DbNode", secondary="db_dblink",
        primaryjoin="DbNode.id == DbLink.input_id",
        secondaryjoin="DbNode.id == DbLink.output_id",
        backref=backref("inputs_q", passive_deletes=True, lazy='dynamic'),
        lazy='dynamic',
        passive_deletes=True
    )

    def __init__(self, *args, **kwargs):
        """
        Build the node and make sure ``attributes`` and ``extras`` start as
        empty dicts (not ``None``) when they were not given.
        """
        super(DbNode, self).__init__(*args, **kwargs)

        if self.attributes is None:
            self.attributes = dict()

        if self.extras is None:
            self.extras = dict()

    @property
    def outputs(self):
        """Return the list of output nodes (evaluates the ``outputs_q`` query)."""
        return self.outputs_q.all()

    @property
    def inputs(self):
        """Return the list of input nodes (evaluates the ``inputs_q`` query)."""
        return self.inputs_q.all()

    # XXX repetition between django/sqlalchemy here.
    def get_aiida_class(self):
        """
        Return the corresponding aiida instance of class aiida.orm.Node or a
        appropriate subclass.

        :raises DbContentError: if the stored type string cannot be mapped to
            a plugin class name.
        """
        from aiida.common.old_pluginloader import from_type_to_pluginclassname
        from aiida.orm.node import Node
        from aiida.common.pluginloader import load_plugin_safe

        try:
            pluginclassname = from_type_to_pluginclassname(self.type)
        except DbContentError:
            # Re-raise with the node pk and type to make the error actionable.
            raise DbContentError("The type name of node with pk= {} is "
                                 "not valid: '{}'".format(self.pk, self.type))

        PluginClass = load_plugin_safe(Node, 'aiida.orm', pluginclassname, self.type, self.pk)

        return PluginClass(dbnode=self)

    def get_simple_name(self, invalid_result=None):
        """
        Return a string with the last part of the type name.

        If the type is empty, use 'Node'.
        If the type is invalid, or missing (``None``; the column is nullable),
        return the content of the input variable ``invalid_result``.

        :param invalid_result: The value to be returned if the node type is
            not recognized.
        """
        thistype = self.type
        if thistype is None:
            # Unrecognized rather than crashing on ``None.endswith``.
            return invalid_result
        # Fix for base class
        if thistype == "":
            thistype = "node.Node."
        if not thistype.endswith("."):
            return invalid_result
        thistype = thistype[:-1]  # Strip final dot
        return thistype.rpartition('.')[2]

    def set_attr(self, key, value):
        """
        Set ``attributes[key] = value`` and save the node.

        The JSONB dict is mutated in place, so the column is explicitly
        flagged as modified for SQLAlchemy to persist the change.

        :raises ValueError: if ``key`` contains a dot.
        """
        DbNode._set_attr(self.attributes, key, value)
        flag_modified(self, "attributes")
        self.save()

    def del_attr(self, key):
        """
        Delete ``attributes[key]`` and save the node.

        :raises ValueError: if ``key`` contains a dot or is not present.
        """
        DbNode._del_attr(self.attributes, key)
        flag_modified(self, "attributes")
        self.save()

    @staticmethod
    def _set_attr(d, key, value):
        """Set ``d[key] = value``; dotted (nested) keys are not supported."""
        if '.' in key:
            raise ValueError("We don't know how to treat key with dot in it yet")

        d[key] = value

    @staticmethod
    def _del_attr(d, key):
        """Delete ``d[key]``; dotted keys and missing keys raise ValueError."""
        if '.' in key:
            raise ValueError("We don't know how to treat key with dot in it yet")

        if key not in d:
            raise ValueError("Key {} does not exists".format(key))

        del d[key]

    @property
    def pk(self):
        """Alias for the ``id`` primary key."""
        return self.id

    def __str__(self):
        """Return ``"<SimpleName> node [<pk>]"``, followed by ``": <label>"`` if set."""
        simplename = self.get_simple_name(invalid_result="Unknown")
        # node pk + type
        if self.label:
            return "{} node [{}]: {}".format(simplename, self.pk, self.label)
        else:
            return "{} node [{}]".format(simplename, self.pk)

    # User email
    @hybrid_property
    def user_email(self):
        """
        Returns: the email of the user
        """
        return self.user.email

    @user_email.expression
    def user_email(cls):
        """
        Returns: the email of the user at a class level (i.e. in the database)
        """
        return select([DbUser.email]).where(DbUser.id == cls.user_id).label(
            'user_email')

    # Computer name
    @hybrid_property
    def computer_name(self):
        """
        Returns: the name of the computer, or None if the node has no computer
        (``dbcomputer_id`` is nullable)
        """
        if self.dbcomputer is None:
            return None
        return self.dbcomputer.name

    @computer_name.expression
    def computer_name(cls):
        """
        Returns: the name of the computer at a class level (i.e. in the
        database)
        """
        return select([DbComputer.name]).where(
            DbComputer.id == cls.dbcomputer_id).label('computer_name')

    @hybrid_property
    def state(self):
        """
        Return the most recent state from DbCalcState.

        "Most recent" is the first entry of the ordering returned by
        ``sort_states``. Returns None for an unsaved node or a node without
        any recorded state.
        """
        if not self.id:
            return None
        all_states = DbCalcState.query.filter(DbCalcState.dbnode_id == self.id).all()
        if all_states:
            return sort_states(((dbcalcstate.state, dbcalcstate.state.value)
                                for dbcalcstate in all_states),
                               use_key=True)[0]
        else:
            return None

    @state.expression
    def state(cls):
        """
        Return the expression to get the 'latest' state from DbCalcState,
        to be used in queries, where 'latest' is defined using the state order
        defined in _sorted_datastates.
        """
        # Sort first the latest states
        whens = {
            v: idx for idx, v
            in enumerate(_sorted_datastates[::-1], start=1)}
        custom_sort_order = case(value=DbCalcState.state,
                                 whens=whens,
                                 else_=100)  # else: high value to put it at the bottom

        # Add numerical state to string, to allow to sort them
        states_with_num = select([
            DbCalcState.id.label('id'),
            DbCalcState.dbnode_id.label('dbnode_id'),
            DbCalcState.state.label('state_string'),
            custom_sort_order.label('num_state')
        ]).select_from(DbCalcState).alias()

        # Get the most 'recent' state (using the state ordering, and the min
        # function) for each calc
        calc_state_num = select([
            states_with_num.c.dbnode_id.label('dbnode_id'),
            func.min(states_with_num.c.num_state).label('recent_state')
        ]).group_by(states_with_num.c.dbnode_id).alias()

        # Join the most-recent-state table with the DbCalcState table
        all_states_q = select([
            DbCalcState.dbnode_id.label('dbnode_id'),
            DbCalcState.state.label('state_string'),
            calc_state_num.c.recent_state.label('recent_state'),
            custom_sort_order.label('num_state'),
        ]).select_from(
            join(DbCalcState, calc_state_num,
                 DbCalcState.dbnode_id == calc_state_num.c.dbnode_id)
        ).alias()

        # Get the association between each calc and only its corresponding
        # most-recent-state row
        subq = select([
            all_states_q.c.dbnode_id.label('dbnode_id'),
            all_states_q.c.state_string.label('state')
        ]).select_from(all_states_q).where(
            all_states_q.c.num_state == all_states_q.c.recent_state).alias()

        # Final filtering for the actual query
        return select([subq.c.state]). \
            where(
                subq.c.dbnode_id == cls.id,
            ). \
            label('laststate')
class DbLink(Base):
    """
    SQLAlchemy model for the ``db_dblink`` table.

    A row is a labelled, typed link from the ``input`` node to the ``output``
    node.
    """
    __tablename__ = "db_dblink"

    id = Column(Integer, primary_key=True)
    # NOTE(review): unlike output_id, no ON DELETE action is set here, so a
    # node that is still the input of some link presumably cannot be deleted
    # without removing the link first -- confirm this asymmetry is intended.
    input_id = Column(
        Integer,
        ForeignKey('db_dbnode.id', deferrable=True, initially="DEFERRED")
    )
    output_id = Column(
        Integer,
        ForeignKey(
            'db_dbnode.id',
            ondelete="CASCADE",
            deferrable=True,
            initially="DEFERRED"
        )
    )

    input = relationship("DbNode", primaryjoin="DbLink.input_id == DbNode.id")
    output = relationship("DbNode", primaryjoin="DbLink.output_id == DbNode.id")

    label = Column(String(255), index=True, nullable=False)
    type = Column(String(255))

    # A calculation can have both a 'return' and a 'create' link to
    # a single data output node, which would violate the unique constraint
    # defined below, since the difference in link type is not considered.
    # The distinction between the type of a 'create' and a 'return' link is not
    # implemented at the moment, so the unique constraint is disabled.
    __table_args__ = (
        # I cannot add twice the same link
        # I want unique labels among all inputs of a node
        # UniqueConstraint('output_id', 'label'),
    )

    @staticmethod
    def _endpoint_description(node):
        """Return ``(simple name, pk)`` for one end of the link, tolerating ``None``."""
        # Both FK columns are nullable, so an endpoint may be missing; printing
        # a link should not raise in that case.
        if node is None:
            return "Unknown node", None
        return node.get_simple_name(invalid_result="Unknown node"), node.pk

    def __str__(self):
        """Return ``"<InName> (<in pk>) --> <OutName> (<out pk>)"``."""
        in_name, in_pk = self._endpoint_description(self.input)
        out_name, out_pk = self._endpoint_description(self.output)
        return "{} ({}) --> {} ({})".format(in_name, in_pk, out_name, out_pk)