# Source code for aiida.storage.psql_dos.backend

# Copyright (c), The AiiDA team. All rights reserved.                     #
# This file is part of the AiiDA code.                                    #
#                                                                         #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file        #
# For further information please visit http://www.aiida.net               #
"""SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`."""
import functools
import gc
import pathlib
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from pydantic import BaseModel, Field
from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
from aiida.common.log import AIIDA_LOGGER
from aiida.manage.configuration.profile import Profile
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import BackendEntity, StorageBackend
from aiida.storage.log import STORAGE_LOGGER
from aiida.storage.psql_dos.migrator import REPOSITORY_UUID_KEY, PsqlDosMigrator
from aiida.storage.psql_dos.models import base

from .orm import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users

if TYPE_CHECKING:
    from aiida.repository.backend import DiskObjectStoreRepositoryBackend

__all__ = ('PsqlDosBackend',)

# Module-level logger, created as a child of the AiiDA root logger and named after this file.
LOGGER = AIIDA_LOGGER.getChild(__file__)

# Default settings for the disk-objectstore container.
# NOTE(review): the `NAME = {` opener and the closing `}` of this mapping were missing from the
# file; the name `CONTAINER_DEFAULTS` is restored from the upstream project - confirm.
CONTAINER_DEFAULTS: dict = {
    'pack_size_target': 4 * 1024 * 1024 * 1024,
    'loose_prefix_len': 2,
    'hash_type': 'sha256',
    'compression_algorithm': 'zlib+1',
}

def get_filepath_container(profile: Profile) -> pathlib.Path:
    """Return the filepath of the disk-object store container.

    The location is derived from the ``repository_uri`` entry of ``profile.storage_config``: the path
    component of that URI, with ``container`` appended.

    :param profile: the profile whose ``storage_config`` holds the ``repository_uri``.
    :return: path of the ``container`` directory inside the repository folder.
    :raises KeyError: if ``repository_uri`` is not defined in the storage configuration.
    :raises ConfigurationError: if the URI scheme is not ``file`` or its path is not absolute.
    """
    from urllib.parse import urlparse

    # NOTE(review): the f-string fields below were empty (``{}``), which is a syntax error;
    # ``profile.name`` is assumed to be the intended value - confirm.
    try:
        parts = urlparse(profile.storage_config['repository_uri'])
    except KeyError as exception:
        raise KeyError(
            f'invalid profile {profile.name}: `repository_uri` not defined in `storage.config`.'
        ) from exception

    if parts.scheme != 'file':
        raise ConfigurationError(
            f'invalid profile {profile.name}: `storage.config.repository_uri` does not start with `file://`.'
        )

    filepath = pathlib.Path(parts.path)

    if not filepath.is_absolute():
        raise ConfigurationError(f'invalid profile {profile.name}: `storage.config.repository_uri` is not absolute')

    return filepath.expanduser() / 'container'
class PsqlDosBackend(StorageBackend):
    """An AiiDA storage backend that stores data in a PostgreSQL database and disk-objectstore repository.

    Note, there were originally two such backends, `sqlalchemy` and `django`.
    The `django` backend was removed, to consolidate access to this storage.

    NOTE(review): several tokens of this class were missing from the file (module paths of the
    function-local imports, f-string fields, logging call prefixes, part of two expressions). They have
    been restored from the upstream project and are individually marked with ``NOTE(review)``.
    """

    class Configuration(BaseModel):
        """Model describing required information to configure an instance of the storage."""

        database_engine: str = Field(
            title='PostgreSQL engine',
            description='The engine to use to connect to the database.',
            default='postgresql_psycopg2',
        )
        database_hostname: str = Field(
            title='PostgreSQL hostname', description='The hostname of the PostgreSQL server.', default='localhost'
        )
        database_port: int = Field(
            title='PostgreSQL port', description='The port of the PostgreSQL server.', default=5432
        )
        database_username: str = Field(
            title='PostgreSQL username', description='The username with which to connect to the PostgreSQL server.'
        )
        database_password: str = Field(
            title='PostgreSQL password', description='The password with which to connect to the PostgreSQL server.'
        )
        database_name: str = Field(
            title='PostgreSQL database name', description='The name of the database in the PostgreSQL server.'
        )

    # Migrator class used by the classmethods below and by ``migrator_context``.
    migrator = PsqlDosMigrator

    @classmethod
    def version_head(cls) -> str:
        """Return the schema version head as reported by the migrator class."""
        return cls.migrator.get_schema_version_head()

    @classmethod
    def version_profile(cls, profile: Profile) -> Optional[str]:
        """Return the schema version of the given profile's storage, as reported by the migrator."""
        with cls.migrator_context(profile) as migrator:
            return migrator.get_schema_version_profile(check_legacy=True)

    @classmethod
    def initialise(cls, profile: Profile, reset: bool = False) -> bool:
        """Initialise the storage of the given profile through the migrator.

        :param profile: the profile whose storage to initialise.
        :param reset: forwarded to the migrator's ``initialise``.
        :return: the value returned by the migrator's ``initialise``.
        """
        with cls.migrator_context(profile) as migrator:
            return migrator.initialise(reset=reset)

    @classmethod
    def migrate(cls, profile: Profile) -> None:
        """Run the migrator's ``migrate`` for the storage of the given profile."""
        with cls.migrator_context(profile) as migrator:
            migrator.migrate()

    @classmethod
    @contextmanager
    def migrator_context(cls, profile: Profile):
        """Yield a migrator instance for the profile and close it when the context exits."""
        migrator = cls.migrator(profile)
        try:
            yield migrator
        finally:
            migrator.close()

    def __init__(self, profile: Profile) -> None:
        """Validate the storage of the profile, open a session factory and create the entity collections.

        :param profile: the profile whose storage this instance gives access to.
        """
        super().__init__(profile)

        # check that the storage is reachable and at the correct version
        with self.migrator_context(profile) as migrator:
            migrator.validate_storage()

        self._session_factory: Optional[scoped_session] = None
        self._initialise_session()
        # save the URL of the database, for use in the __str__ method
        self._db_url = self.get_session().get_bind().url  # type: ignore[union-attr]

        self._authinfos = authinfos.SqlaAuthInfoCollection(self)
        self._comments = comments.SqlaCommentCollection(self)
        self._computers = computers.SqlaComputerCollection(self)
        self._groups = groups.SqlaGroupCollection(self)
        self._logs = logs.SqlaLogCollection(self)
        self._nodes = nodes.SqlaNodeCollection(self)
        self._users = users.SqlaUserCollection(self)

    @property
    def is_closed(self) -> bool:
        """Return whether the storage is closed, i.e. whether the session factory has been discarded."""
        return self._session_factory is None

    def __str__(self) -> str:
        """Return a description with the open/closed state, the database URL and the repository."""
        state = 'closed' if self.is_closed else 'open'
        # NOTE(review): the first field was empty (``{!r}``), which is a syntax error;
        # ``self.profile.name`` is assumed to be the intended value - confirm.
        return f'Storage for {self.profile.name!r} [{state}] @ {self._db_url!r} / {self.get_repository()}'

    def _initialise_session(self):
        """Initialise the SQLAlchemy session factory.

        Only one session factory is ever associated with a given class instance,
        i.e. once the instance is closed, it cannot be reopened.

        The session factory, returns a session that is bound to the current thread.
        Multi-thread support is currently required by the REST API.
        Although, in the future, we may want to move the multi-thread handling to higher in the AiiDA stack.
        """
        # NOTE(review): module path restored (was ``from import ...``) - confirm.
        from aiida.storage.psql_dos.utils import create_sqlalchemy_engine

        engine = create_sqlalchemy_engine(self._profile.storage_config)  # type: ignore[arg-type]
        self._session_factory = scoped_session(sessionmaker(bind=engine, future=True, expire_on_commit=True))

    def get_session(self) -> Session:
        """Return an SQLAlchemy session bound to the current thread.

        :raises ClosedStorage: if the session factory has already been discarded by ``close``.
        """
        if self._session_factory is None:
            raise ClosedStorage(str(self))
        return self._session_factory()

    def close(self) -> None:
        """Dispose of the engine, close the session factory and mark this instance as closed."""
        if self._session_factory is None:
            return  # the instance is already closed, and so this is a no-op

        # close the connection
        engine = self._session_factory.bind
        if engine is not None:
            engine.dispose()  # type: ignore[union-attr]
        self._session_factory.expunge_all()
        self._session_factory.close()
        self._session_factory = None

        # Without this, sqlalchemy keeps a weakref to a session
        # in sqlalchemy.orm.session._sessions
        gc.collect()

    def _clear(self) -> None:
        """Delete all tables except the version and settings tables, then reset the repository.

        After the repository has been re-initialised, its new UUID is written to the settings table.
        """
        # NOTE(review): module path restored (was ``from import ...``) - confirm.
        from aiida.storage.psql_dos.models.settings import DbSetting

        super()._clear()

        with self.migrator_context(self._profile) as migrator:
            # Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
            # to the PostgreSQL server and it will block the deletion and the command will hang.
            self.get_session().close()
            exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
            migrator.delete_all_tables(exclude_tables=exclude_tables)

            # Clear out all references to database model instances which are now invalid.
            self.get_session().expunge_all()

            # Now reset and reinitialise the repository
            migrator.reset_repository()
            migrator.initialise_repository()
            repository_uuid = migrator.get_repository_uuid()

            with self.transaction() as session:
                session.execute(
                    DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY).values(val=repository_uuid)
                )

    def get_repository(self) -> 'DiskObjectStoreRepositoryBackend':
        """Return a repository backend wrapping the disk-objectstore container of this profile."""
        from disk_objectstore import Container

        from aiida.repository.backend import DiskObjectStoreRepositoryBackend

        container = Container(get_filepath_container(self.profile))
        return DiskObjectStoreRepositoryBackend(container=container)

    @property
    def authinfos(self):
        """Return the collection of authinfo entities."""
        return self._authinfos

    @property
    def comments(self):
        """Return the collection of comment entities."""
        return self._comments

    @property
    def computers(self):
        """Return the collection of computer entities."""
        return self._computers

    @property
    def groups(self):
        """Return the collection of group entities."""
        return self._groups

    @property
    def logs(self):
        """Return the collection of log entities."""
        return self._logs

    @property
    def nodes(self):
        """Return the collection of node entities."""
        return self._nodes

    def query(self):
        """Return a new query builder instance bound to this backend."""
        return querybuilder.SqlaQueryBuilder(self)

    @property
    def users(self):
        """Return the collection of user entities."""
        return self._users

    @contextmanager
    def transaction(self) -> Iterator[Session]:
        """Open a transaction to be used as a context manager.

        If there is an exception within the context then the changes will be rolled back and the state will be as
        before entering. Transactions can be nested.
        """
        session = self.get_session()
        if session.in_transaction():
            # A transaction is already open: work inside a savepoint, then commit the session.
            with session.begin_nested() as savepoint:
                yield session
                savepoint.commit()
            session.commit()
        else:
            # No transaction yet: open one, and a savepoint inside it so that ``in_transaction`` is true.
            with session.begin():
                with session.begin_nested() as savepoint:
                    yield session
                    savepoint.commit()

    @property
    def in_transaction(self) -> bool:
        """Return whether the current session is inside a nested transaction (savepoint)."""
        return self.get_session().in_nested_transaction()

    @staticmethod
    @functools.lru_cache(maxsize=18)
    def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
        """Return the Sqlalchemy mapper and fields corresponding to the given entity.

        :param entity_type: the entity type for which to look up the database model.
        :param with_pk: if True, the fields returned will include the primary key
        :return: tuple of the mapper and the set of column keys.
        """
        from sqlalchemy import inspect

        # NOTE(review): module paths restored (were ``from import ...``) - confirm.
        from aiida.storage.psql_dos.models.authinfo import DbAuthInfo
        from aiida.storage.psql_dos.models.comment import DbComment
        from aiida.storage.psql_dos.models.computer import DbComputer
        from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode
        from aiida.storage.psql_dos.models.log import DbLog
        from aiida.storage.psql_dos.models.node import DbLink, DbNode
        from aiida.storage.psql_dos.models.user import DbUser

        model = {
            EntityTypes.AUTHINFO: DbAuthInfo,
            EntityTypes.COMMENT: DbComment,
            EntityTypes.COMPUTER: DbComputer,
            EntityTypes.GROUP: DbGroup,
            EntityTypes.LOG: DbLog,
            EntityTypes.NODE: DbNode,
            EntityTypes.USER: DbUser,
            EntityTypes.LINK: DbLink,
            EntityTypes.GROUP_NODE: DbGroupNode,
        }[entity_type]
        mapper = inspect(model).mapper
        keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
        return mapper, keys

    def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
        """Insert a list of rows for the given entity type and return the ids of the inserted rows.

        :param entity_type: the type of entity to insert.
        :param rows: one dictionary per row; for computers, logs and authinfos the ``metadata`` key is renamed
            to ``_metadata`` in place.
        :param allow_defaults: if False every row must provide exactly the non-primary-key fields; if True the
            row keys only need to be a subset of them.
        :return: the ``id`` of each inserted row.
        :raises IntegrityError: if a row contains incorrect fields.
        """
        mapper, keys = self._get_mapper_from_entity(entity_type, False)
        if not rows:
            return []
        if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO):
            for row in rows:
                row['_metadata'] = row.pop('metadata')
        if allow_defaults:
            for row in rows:
                if not keys.issuperset(row):
                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
        else:
            for row in rows:
                if set(row) != keys:
                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
        # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation;
        # by contrast, in sqlite, bulk_insert is faster (the reference links of this comment were lost).
        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
        # NOTE(review): the element expression was missing (``[ for row in result]``); ``row.id`` matches the
        # ``column('id')`` requested in ``returning`` above - confirm.
        return [row.id for row in result]

    def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
        """Update a list of rows of the given entity type, each identified by its ``id``.

        :param entity_type: the type of entity to update.
        :param rows: one dictionary per row, which must contain ``id`` and only known fields.
        :raises IntegrityError: if a row has no ``id`` or contains incorrect fields.
        """
        mapper, keys = self._get_mapper_from_entity(entity_type, True)
        if not rows:
            return None
        for row in rows:
            if 'id' not in row:
                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
            if not keys.issuperset(row):
                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            session.execute(update(mapper), rows)

    def delete(self, delete_database_user: bool = False) -> None:
        """Delete the storage and all the data.

        :param delete_database_user: Also delete the database user. This is ``False`` by default because the user may be
            used by other databases.
        """
        import shutil

        from aiida.manage.external.postgres import Postgres

        profile = self.profile
        config = profile.storage_config
        postgres = Postgres.from_profile(self.profile)
        repository = get_filepath_container(profile).parent

        # NOTE(review): the three logging calls below had lost their ``<logger>.<level>(f`` prefix, leaving
        # dangling string literals; ``LOGGER.report`` is restored from the upstream project - confirm.
        if repository.exists():
            shutil.rmtree(repository)
            LOGGER.report(f'Deleted repository at `{repository}`.')

        if postgres.db_exists(config['database_name']):
            postgres.drop_db(config['database_name'])
            LOGGER.report(f'Deleted database `{config["database_name"]}`.')

        if delete_database_user and postgres.dbuser_exists(config['database_username']):
            postgres.drop_dbuser(config['database_username'])
            LOGGER.report(f'Deleted database user `{config["database_username"]}`.')

    def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
        """Delete the given nodes together with their group memberships and links.

        :param pks_to_delete: primary keys of the nodes to delete.
        :raises AssertionError: if called outside a transaction.
        """
        # NOTE(review): module paths restored (were ``from import ...``) - confirm.
        from aiida.storage.psql_dos.models.group import DbGroupNode
        from aiida.storage.psql_dos.models.node import DbLink, DbNode

        if not self.in_transaction:
            raise AssertionError('Cannot delete nodes and links outside a transaction')

        session = self.get_session()
        # Delete the membership of these nodes to groups.
        session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))).delete(
            synchronize_session='fetch'
        )
        # Delete the links coming out of the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
        # Delete the links pointing to the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
        # Delete the actual nodes
        # NOTE(review): this statement was truncated to ``filter('fetch')``; the primary-key filter and the
        # ``delete`` call are restored by analogy with the three statements above - confirm.
        session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

    def get_backend_entity(self, model: base.Base) -> BackendEntity:
        """Return the backend entity that corresponds to the given Model instance

        :param model: the ORM model instance to promote to a backend instance
        :return: the backend entity corresponding to the given model
        """
        return convert.get_backend_entity(model, self)

    def set_global_variable(
        self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True
    ) -> None:
        """Store a key/value pair in the settings table.

        :param key: the key of the setting.
        :param value: the value to store.
        :param description: description stored with a newly created setting; not applied when updating.
        :param overwrite: whether an existing setting with this key may be updated.
        :raises ValueError: if the key already exists and ``overwrite`` is false.
        """
        # NOTE(review): module path restored (was ``from import ...``) - confirm.
        from aiida.storage.psql_dos.models.settings import DbSetting

        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            if session.query(DbSetting).filter(DbSetting.key == key).count():
                if overwrite:
                    session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value))
                else:
                    raise ValueError(f'The setting {key} already exists')
            else:
                session.add(DbSetting(key=key, val=value, description=description or ''))

    def get_global_variable(self, key: str) -> Union[None, str, int, float]:
        """Return the value stored in the settings table under the given key.

        :param key: the key of the setting.
        :raises KeyError: if no setting with this key exists.
        """
        # NOTE(review): module path restored (was ``from import ...``) - confirm.
        from aiida.storage.psql_dos.models.settings import DbSetting

        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()
            if setting is None:
                raise KeyError(f'No setting found with key {key}')
            return setting.val

    def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None:
        """Delete unreferenced repository objects and run the repository's own maintenance.

        :param full: if True the work is done while holding the profile lock and the repository is maintained
            with ``live=False``; otherwise no lock is taken and ``live=True`` is used.
        :param dry_run: if True the unreferenced objects are not deleted; the flag is also forwarded to the
            repository.
        :param kwargs: forwarded to the repository's ``maintain``.
        """
        from aiida.manage.profile_access import ProfileAccessManager

        repository = self.get_repository()

        if full:
            maintenance_context = ProfileAccessManager(self._profile).lock
        else:
            maintenance_context = nullcontext  # type: ignore[assignment]

        # NOTE(review): the two logging calls below had lost their ``<logger>.<level>(`` prefix (and the
        # ``f`` of the first string); ``STORAGE_LOGGER.info`` is restored from the upstream project - confirm.
        with maintenance_context():
            unreferenced_objects = self.get_unreferenced_keyset()
            STORAGE_LOGGER.info(f'Deleting {len(unreferenced_objects)} unreferenced objects ...')
            if not dry_run:
                repository.delete_objects(list(unreferenced_objects))

            STORAGE_LOGGER.info('Starting repository-specific operations ...')
            repository.maintain(live=not full, dry_run=dry_run, **kwargs)

    def get_unreferenced_keyset(self, check_consistency: bool = True) -> Set[str]:
        """Returns the keyset of objects that exist in the repository but are not tracked by AiiDA.

        This should be all the soft-deleted files.

        :param check_consistency: toggle for a check that raises if there are references in the database with no
            actual object in the underlying repository.
        :return: a set with all the objects in the underlying repository that are not referenced in the database.
        :raises RuntimeError: if ``check_consistency`` is set and the database references missing objects.
        """
        from aiida import orm

        # NOTE(review): logging call prefix restored (only the string literal remained) - confirm.
        STORAGE_LOGGER.info('Obtaining unreferenced object keys ...')

        repository = self.get_repository()

        keyset_repository = set(repository.list_objects())
        keyset_database = set(orm.Node.get_collection(self).iter_repo_keys())

        if check_consistency:
            keyset_missing = keyset_database - keyset_repository
            if len(keyset_missing) > 0:
                raise RuntimeError(
                    'There are objects referenced in the database that are not present in the repository. Aborting!'
                )

        return keyset_repository - keyset_database

    def get_info(self, detailed: bool = False) -> dict:
        """Return the information of the parent class extended with a ``repository`` entry.

        :param detailed: forwarded to both the parent implementation and the repository.
        """
        results = super().get_info(detailed=detailed)
        results['repository'] = self.get_repository().get_info(detailed)
        return results