###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to manage the autogrouping functionality by ``verdi run``."""
from __future__ import annotations
import re
from aiida.common import exceptions, timezone
from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql
from aiida.orm import AutoGroup, QueryBuilder
from aiida.plugins.entry_point import get_entry_point_string_from_class
[docs]
class AutogroupManager:
"""Class to automatically add all newly stored ``Node``s to an ``AutoGroup`` (whilst enabled).
This class should not be instantiated directly, but rather accessed through the backend storage instance.
The auto-grouping is checked by the ``Node.store()`` method which, if ``is_to_be_grouped`` is true,
will store the node in the associated ``AutoGroup``.
The exclude/include lists are lists of strings like:
``aiida.data:core.int``, ``aiida.calculation:quantumespresso.pw``,
``aiida.data:core.array.%``, ...
i.e.: a string identifying the base class, followed by a colon and the path to the class
as accepted by CalculationFactory/DataFactory.
Each string can contain one or more wildcard characters ``%``;
in this case this is used in a ``like`` comparison with the QueryBuilder.
Note that in this case you have to remember that ``_`` means "any character"
in the QueryBuilder, and you need to escape it if you mean a literal underscore.
Only one of the two (between exclude and include) can be set.
If none of the two is set, everything is included.
"""
[docs]
def __init__(self, backend):
"""Initialize the manager for the storage backend."""
self._backend = backend
self._enabled = False
self._exclude: list[str] | None = None
self._include: list[str] | None = None
self._group_label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}"
self._group_label = None # Actual group label, set by `get_or_create_group`
@property
def is_enabled(self) -> bool:
"""Return whether auto-grouping is enabled."""
return self._enabled
[docs]
def enable(self) -> None:
"""Enable the auto-grouping."""
self._enabled = True
[docs]
def disable(self) -> None:
"""Disable the auto-grouping."""
self._enabled = False
[docs]
def get_exclude(self) -> list[str] | None:
"""Return the list of classes to exclude from autogrouping.
Returns ``None`` if no exclusion list has been set.
"""
return self._exclude
[docs]
def get_include(self) -> list[str] | None:
"""Return the list of classes to include in the autogrouping.
Returns ``None`` if no inclusion list has been set.
"""
return self._include
[docs]
def get_group_label_prefix(self) -> str:
"""Get the prefix of the label of the group.
If no group label prefix was set, it will set a default one by itself.
"""
return self._group_label_prefix
[docs]
@staticmethod
def validate(strings: list[str] | None):
"""Validate the list of strings passed to set_include and set_exclude."""
if strings is None:
return
valid_prefixes = {'aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data'}
for string in strings:
pieces = string.split(':')
if len(pieces) != 2:
raise exceptions.ValidationError(
f"'{string}' is not a valid include/exclude filter, must contain two parts split by a colon"
)
if pieces[0] not in valid_prefixes:
raise exceptions.ValidationError(
f"'{string}' has an invalid prefix, must be among: {sorted(valid_prefixes)}"
)
[docs]
def set_exclude(self, exclude: list[str] | str | None) -> None:
"""Set the list of classes to exclude in the autogrouping.
:param exclude: a list of valid entry point strings (might contain '%' to be used as
string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None``
to specify no include list.
"""
if isinstance(exclude, str):
exclude = [exclude]
self.validate(exclude)
if exclude is not None and self.get_include() is not None:
# It's ok to set None, both as a default, or to 'undo' the exclude list
raise exceptions.ValidationError('Cannot both specify exclude and include')
self._exclude = exclude
[docs]
def set_include(self, include: list[str] | str | None) -> None:
"""Set the list of classes to include in the autogrouping.
:param include: a list of valid entry point strings (might contain '%' to be used as
string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None``
to specify no include list.
"""
if isinstance(include, str):
include = [include]
self.validate(include)
if include is not None and self.get_exclude() is not None:
# It's ok to set None, both as a default, or to 'undo' the include list
raise exceptions.ValidationError('Cannot both specify exclude and include')
self._include = include
[docs]
def set_group_label_prefix(self, label_prefix: str | None) -> None:
"""Set the label of the group to be created (or use a default)."""
if label_prefix is None:
label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}"
if not isinstance(label_prefix, str):
raise exceptions.ValidationError('group label must be a string')
self._group_label_prefix = label_prefix
self._group_label = None # reset the actual group label
[docs]
@staticmethod
def _matches(string, filter_string):
"""Check if 'string' matches the 'filter_string' (used for include and exclude filters).
If 'filter_string' does not contain any % sign, perform an exact match.
Otherwise, match with a SQL-like query, where % means any character sequence,
and _ means a single character (these characters can be escaped with a backslash).
:param string: the string to match.
:param filter_string: the filter string.
"""
if '%' in filter_string:
regex_filter = get_regex_pattern_from_sql(filter_string)
return re.match(regex_filter, string) is not None
return string == filter_string
[docs]
def is_to_be_grouped(self, node) -> bool:
"""Return whether the given node is to be auto-grouped according to enable state and include/exclude lists."""
if not self._enabled:
return False
# strings, including possibly 'all'
include = self.get_include()
exclude = self.get_exclude()
if include is None and exclude is None:
# Include all classes by default if nothing is explicitly specified.
return True
# We should never be here, anyway - this should be catched by the `set_include/exclude` methods
assert include is None or exclude is None, "You cannot specify both an 'include' and an 'exclude' list"
entry_point_string = node.process_type
# If there is no `process_type` we are dealing with a `Data` node so we get the entry point from the class
if not entry_point_string:
entry_point_string = get_entry_point_string_from_class(node.__class__.__module__, node.__class__.__name__)
if include is not None:
# As soon as a filter string matches, we include the class
return any(self._matches(entry_point_string, filter_string) for filter_string in include)
# If we are here, exclude is not None
# include *only* in *none* of the filters match (that is, exclude as
# soon as any of the filters matches)
return not any(self._matches(entry_point_string, filter_string) for filter_string in (exclude or []))
[docs]
def get_or_create_group(self) -> AutoGroup:
"""Return the current `AutoGroup`, or create one if None has been set yet.
This function implements a somewhat complex logic that is however needed
to make sure that, even if `verdi run` is called at the same time multiple
times, e.g. in a for loop in bash, there is never the risk that two ``verdi run``
Unix processes try to create the same group, with the same label, ending
up in a crash of the code (see PR #3650).
Here, instead, we make sure that if this concurrency issue happens,
one of the two will get a IntegrityError from the DB, and then recover
trying to create a group with a different label (with a numeric suffix appended),
until it manages to create it.
"""
# When this function is called, if it is the first time, just generate
# a new group name (later on, after this ``if`` block`).
# In that case, we will later cache in ``self._group_label`` the group label,
# So the group with the same name can be returned quickly in future
# calls of this method.
if self._group_label is not None:
builder = QueryBuilder(backend=self._backend).append(AutoGroup, filters={'label': self._group_label})
results = [res[0] for res in builder.iterall()]
if results:
# If it is not empty, it should have only one result due to the uniqueness constraints
assert len(results) == 1, 'I got more than one autogroup with the same label!'
return results[0]
# There are no results: probably the group has been deleted.
# I continue as if it was not cached
self._group_label = None
label_prefix = self.get_group_label_prefix()
# Try to do a preliminary QB query to avoid to do too many try/except
# if many of the prefix_NUMBER groups already exist
queryb = QueryBuilder(self._backend).append(
AutoGroup,
filters={
'or': [
{'label': {'==': label_prefix}},
{'label': {'like': f"{escape_for_sql_like(f'{label_prefix}_')}%"}},
]
},
project='label',
)
existing_group_labels = [res[0][len(label_prefix) :] for res in queryb.all()]
existing_group_ints = []
for label in existing_group_labels:
if label == '':
# This is just the prefix without name - corresponds to counter = 0
existing_group_ints.append(0)
elif label.startswith('_'):
try:
existing_group_ints.append(int(label[1:]))
except ValueError:
# It's not an integer, so it will never collide - just ignore it
pass
if not existing_group_ints:
counter = 0
else:
counter = max(existing_group_ints) + 1
while True:
try:
label = label_prefix if counter == 0 else f'{label_prefix}_{counter}'
group = AutoGroup(backend=self._backend, label=label).store()
self._group_label = group.label
except exceptions.IntegrityError:
counter += 1
else:
break
return group