# Source code for aiida.manage.caching

###########################################################################
# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
###########################################################################
"""Definition of caching mechanism and configuration for calculations."""
from __future__ import annotations

import keyword
import re
from collections import namedtuple
from contextlib import contextmanager, suppress
from enum import Enum

from aiida.common import exceptions
from aiida.common.lang import type_check
from aiida.manage.configuration import get_config_option
from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP, ENTRY_POINT_STRING_SEPARATOR

__all__ = ('get_use_cache', 'enable_caching', 'disable_caching')


class ConfigKeys(Enum):
    """Configuration option names that hold the caching settings."""

    # Fallback value returned for identifiers that match no enabled/disabled pattern.
    DEFAULT = 'caching.default_enabled'
    # Identifier patterns for which caching is switched on.
    ENABLED = 'caching.enabled_for'
    # Identifier patterns for which caching is switched off.
    DISABLED = 'caching.disabled_for'
[docs] class _ContextCache: """Cache options, accounting for when in enable_caching or disable_caching contexts."""
[docs] def __init__(self): self._default_all = None self._enable = [] self._disable = []
[docs] def clear(self): """Clear caching overrides.""" self.__init__() # type: ignore[misc]
[docs] def enable_all(self): self._default_all = 'enable'
[docs] def disable_all(self): self._default_all = 'disable'
[docs] def enable(self, identifier: str): self._enable.append(identifier) with suppress(ValueError): self._disable.remove(identifier)
[docs] def disable(self, identifier: str): self._disable.append(identifier) with suppress(ValueError): self._enable.remove(identifier)
[docs] def get_options(self, strict: bool = False): """Return the options, applying any context overrides. :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if it fails, an exception is raised. """ if self._default_all == 'disable': return False, [], [] if self._default_all == 'enable': return True, [], [] default = get_config_option(ConfigKeys.DEFAULT.value) enabled = get_config_option(ConfigKeys.ENABLED.value)[:] disabled = get_config_option(ConfigKeys.DISABLED.value)[:] for ident in self._disable: disabled.append(ident) with suppress(ValueError): enabled.remove(ident) for ident in self._enable: enabled.append(ident) with suppress(ValueError): disabled.remove(ident) # Check validity of enabled and disabled entries try: for identifier in enabled + disabled: _validate_identifier_pattern(identifier=identifier, strict=strict) except ValueError as exc: raise exceptions.ConfigurationError('Invalid identifier pattern in enable or disable list.') from exc return default, enabled, disabled
# Module-level instance holding the overrides set by ``enable_caching``/``disable_caching``; ``get_use_cache`` reads
# the effective options from it. Being a module global, it only affects the current Python interpreter.
_CONTEXT_CACHE = _ContextCache()
@contextmanager
def enable_caching(*, identifier: str | None = None, strict: bool = False):
    """Context manager to enable caching, either for a specific node class, or globally.

    .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter.

    .. note:: on exit all context overrides are cleared, including those set by an enclosing
        ``enable_caching``/``disable_caching`` context.

    :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it.
        If not provided, caching is enabled for all classes.
    :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
        it fails, an exception is raised.
    """
    type_check(identifier, str, allow_none=True)

    if identifier is None:
        _CONTEXT_CACHE.enable_all()
    else:
        _validate_identifier_pattern(identifier=identifier, strict=strict)
        _CONTEXT_CACHE.enable(identifier)

    try:
        yield
    finally:
        # Reset even if the ``with`` body raised, so the override does not leak into the rest of the session.
        _CONTEXT_CACHE.clear()
@contextmanager
def disable_caching(*, identifier: str | None = None, strict: bool = False):
    """Context manager to disable caching, either for a specific node class, or globally.

    .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter.

    .. note:: on exit all context overrides are cleared, including those set by an enclosing
        ``enable_caching``/``disable_caching`` context.

    :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it.
        If not provided, caching is disabled for all classes.
    :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
        it fails, an exception is raised.
    """
    type_check(identifier, str, allow_none=True)

    if identifier is None:
        _CONTEXT_CACHE.disable_all()
    else:
        _validate_identifier_pattern(identifier=identifier, strict=strict)
        _CONTEXT_CACHE.disable(identifier)

    try:
        yield
    finally:
        # Reset even if the ``with`` body raised, so the override does not leak into the rest of the session.
        _CONTEXT_CACHE.clear()
# A matching pattern together with the caching decision it stands for; defined once at import time.
_PatternWithResult = namedtuple('_PatternWithResult', ['pattern', 'use_cache'])


def get_use_cache(*, identifier: str | None = None, strict: bool = False) -> bool:
    """Return whether the caching mechanism should be used for the given process type according to the configuration.

    When ``identifier`` matches patterns from both the enabled and the disabled list, the most specific pattern decides.
    A pattern counts as most specific when every other matching pattern also matches it. A pattern listed more than once
    in the same list is considered only once.

    :param identifier: Process type string of the node
    :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
        it fails, an exception is raised.
    :return: boolean, True if caching is enabled, False otherwise
    :raises: `~aiida.common.exceptions.ConfigurationError` if the configuration is invalid, either due to a general
        configuration error, or by defining the class both enabled and disabled
    """
    type_check(identifier, str, allow_none=True)

    # Always resolve the options first: this also validates the configured patterns.
    default, enabled, disabled = _CONTEXT_CACHE.get_options(strict=strict)

    if identifier is None:
        return default

    # ``dict.fromkeys`` removes duplicate patterns while preserving order; otherwise a pattern listed twice in the same
    # list would show up twice below and be reported as a non-unique most specific match.
    enable_matches = list(
        dict.fromkeys(pattern for pattern in enabled if _match_wildcard(string=identifier, pattern=pattern))
    )
    disable_matches = list(
        dict.fromkeys(pattern for pattern in disabled if _match_wildcard(string=identifier, pattern=pattern))
    )

    if enable_matches and disable_matches:
        # If both enable and disable have matching identifier, we search for the most specific one. This is determined
        # by checking whether all other patterns match the specific pattern.
        all_matches = enable_matches + disable_matches
        candidates = [_PatternWithResult(pattern=pattern, use_cache=True) for pattern in enable_matches] + [
            _PatternWithResult(pattern=pattern, use_cache=False) for pattern in disable_matches
        ]
        most_specific = [
            candidate
            for candidate in candidates
            if all(_match_wildcard(string=candidate.pattern, pattern=other) for other in all_matches)
        ]

        if len(most_specific) > 1:
            raise exceptions.ConfigurationError(
                f'Invalid configuration: multiple matches for identifier `{identifier}`, but the most specific '
                f'identifier is not unique. Candidates: {[match.pattern for match in most_specific]}'
            )
        if not most_specific:
            raise exceptions.ConfigurationError(
                f'Invalid configuration: multiple matches for identifier `{identifier}`, but none of them is most '
                'specific.'
            )
        return most_specific[0].use_cache

    if enable_matches:
        return True
    if disable_matches:
        return False
    return default
[docs] def _match_wildcard(*, string: str, pattern: str) -> bool: """Return whether a given name matches a pattern which can contain '*' wildcards. :param string: The string to check. :param pattern: The patter to match for. :returns: ``True`` if ``string`` matches the ``pattern``, ``False`` otherwise. """ regexp = '.*'.join(re.escape(part) for part in pattern.split('*')) return re.fullmatch(pattern=regexp, string=string) is not None
def _validate_identifier_pattern(*, identifier: str, strict: bool = False):
    """Validate a caching identifier pattern.

    The identifier (without wildcards) can have one of two forms:

    1. ``<group_name><ENTRY_POINT_STRING_SEPARATOR><tail>``
       where ``group_name`` is one of the keys in ``ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP``
       and ``tail`` can be anything *except* ``ENTRY_POINT_STRING_SEPARATOR``.

    2. a fully qualified Python name:
       a dot-separated string, where each part satisfies
       ``part.isidentifier() and not keyword.iskeyword(part)``

    This function checks if an identifier *with* wildcards can possibly
    match one of these two forms. If it can not, a ``ValueError`` is raised.

    :param identifier: Process type string, or a pattern with '*' wildcard that matches it.
    :param strict: When set to ``True``, the function will actually try to resolve the identifier by loading it and if
        it fails, an exception is raised. Patterns containing a ``*`` wildcard cannot be resolved and are therefore only
        checked syntactically.
    :raises ValueError: If the identifier is an invalid identifier.
    :raises ValueError: If ``strict=True`` and the identifier cannot be successfully loaded.
    """
    import importlib

    from aiida.common.exceptions import EntryPointError
    from aiida.plugins.entry_point import load_entry_point_from_string

    common_error_msg = f'Invalid identifier pattern `{identifier}`: '

    assert ENTRY_POINT_STRING_SEPARATOR not in '.*'  # The logic of this function depends on this

    # Check if it can be an entry point string
    if identifier.count(ENTRY_POINT_STRING_SEPARATOR) > 1:
        raise ValueError(
            f'{common_error_msg}Can contain at most one entry point string separator `{ENTRY_POINT_STRING_SEPARATOR}`'
        )

    # If there is one separator, it must be an entry point string.
    # Check if the left hand side is a matching pattern
    if ENTRY_POINT_STRING_SEPARATOR in identifier:
        group_pattern, _ = identifier.split(ENTRY_POINT_STRING_SEPARATOR)
        if not any(
            _match_wildcard(string=group_name, pattern=group_pattern)
            for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP
        ):
            raise ValueError(
                common_error_msg
                + f'Group name pattern `{group_pattern}` does not match any of the AiiDA entry point group names.'
            )

        # If strict mode is enabled and the identifier is explicit, i.e., doesn't contain a wildcard, try to load it.
        if strict and '*' not in identifier:
            try:
                load_entry_point_from_string(identifier)
            except EntryPointError as exception:
                raise ValueError(common_error_msg + f'`{identifier}` cannot be loaded.') from exception

        # The group name pattern matches, and there are no further entry point string separators in the identifier,
        # hence it is a valid pattern.
        return

    # The separator might be swallowed in a wildcard, for example
    # aiida.* or aiida.calculations*
    if '*' in identifier:
        group_part, _ = identifier.split('*', 1)
        if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP):
            return

    # Finally, check if it could be a fully qualified Python name
    for identifier_part in identifier.split('.'):
        # If it contains a wildcard, we can not check for keywords.
        # Replacing all wildcards with a single letter must give an
        # identifier - this checks for invalid characters, and that it
        # does not start with a number.
        if '*' in identifier_part:
            if not identifier_part.replace('*', 'a').isidentifier():
                raise ValueError(
                    common_error_msg
                    + f'Identifier part `{identifier_part}` can not match a fully qualified Python name.'
                )
        else:
            if not identifier_part.isidentifier():
                raise ValueError(f'{common_error_msg}`{identifier_part}` is not a valid Python identifier.')
            if keyword.iskeyword(identifier_part):
                raise ValueError(f'{common_error_msg}`{identifier_part}` is a reserved Python keyword.')

    # A wildcard pattern cannot be imported; as in the entry point branch above, strict mode only resolves explicit
    # identifiers.
    if not strict or '*' in identifier:
        return

    # If there is no separator, it must be a fully qualified Python name: ``<module path>.<attribute>``.
    module_name, _, class_name = identifier.rpartition('.')
    if not module_name:
        # A bare name has no module to import from (``importlib.import_module('')`` would raise a confusing error).
        raise ValueError(common_error_msg + f'`{identifier}` cannot be imported.')

    try:
        module = importlib.import_module(module_name)
        getattr(module, class_name)
    except (ImportError, AttributeError) as exc:
        # ``ImportError`` also covers ``ModuleNotFoundError`` and modules that fail while being imported.
        raise ValueError(common_error_msg + f'`{identifier}` cannot be imported.') from exc